leaking grid aggregations

This commit is contained in:
Tobias Hölzer 2025-11-24 22:40:17 +01:00
parent 18cc1b8601
commit 341fa7b836
5 changed files with 521 additions and 154 deletions

85
pixi.lock generated
View file

@ -543,6 +543,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/ab/b5/36c712098e6191d1b4e349304ef73a8d06aed77e56ceaac8c0a306c7bda1/jupyterlab_widgets-3.0.16-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c4/bd/ba44a47578ea48ee28b54543c1de8c529eedad8317516a2a753e6d9c77c5/lonboard-0.13.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/52/40/617b15e62d5de1718e81ee436a1f19d4d40274ead97ac0eda188baebb986/memray-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/93/cf/be4e93afbfa0def2cd6fac9302071db0bd6d0617999ecbf53f92b9398de3/multiurl-0.3.7-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl
@ -594,6 +595,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/39/60/868371b6482ccd9ef423c6f62650066cf8271fdb2ee84f192695ad6b7a96/streamlit-1.51.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/53/b3/95ab646b0c908823d71e49ab8b5949ec9f33346cee3897d1af6be28a8d91/textual-6.6.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/8d/c0/fdf9d3ee103ce66a55f0532835ad5e154226c5222423c6636ba049dc42fc/traittypes-0.2.3-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/06/af/413f6b172f9d4c4943b980a9fd96bb4d915680ce8f79c07de6f697b45c8b/ultraplot-1.65.1-py3-none-any.whl
@ -2808,7 +2810,7 @@ packages:
- pypi: ./
name: entropice
version: 0.1.0
sha256: d95c691c76206bf54e207fe02b100a247a3847f37135d1cdf6ee18165770ea46
sha256: 788df3ea7773f54fce8274d57d51af8edbab106c5b2082f8efca6639aa3eece9
requires_dist:
- aiohttp>=3.12.11
- bokeh>=3.7.3
@ -2862,6 +2864,7 @@ packages:
- s3fs>=2025.10.0,<2026
- xarray-spatial @ git+https://github.com/relativityhd/xarray-spatial
- cupy-xarray>=0.1.4,<0.2
- memray>=1.19.1,<2
requires_python: '>=3.13,<3.14'
editable: true
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
@ -6344,6 +6347,58 @@ packages:
- pkg:pypi/mdurl?source=hash-mapping
size: 14465
timestamp: 1733255681319
- pypi: https://files.pythonhosted.org/packages/52/40/617b15e62d5de1718e81ee436a1f19d4d40274ead97ac0eda188baebb986/memray-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
name: memray
version: 1.19.1
sha256: 9770b6046a8b7d7bb982c2cefb324a60a0ee7cd7c35c58c0df058355a9a54676
requires_dist:
- jinja2>=2.9
- typing-extensions ; python_full_version < '3.8'
- rich>=11.2.0
- textual>=0.41.0
- cython ; extra == 'test'
- greenlet ; python_full_version < '3.14' and extra == 'test'
- pytest ; extra == 'test'
- pytest-cov ; extra == 'test'
- ipython ; extra == 'test'
- setuptools ; extra == 'test'
- pytest-textual-snapshot ; extra == 'test'
- textual>=0.43,!=0.65.2,!=0.66 ; extra == 'test'
- packaging ; extra == 'test'
- ipython ; extra == 'docs'
- bump2version ; extra == 'docs'
- sphinx ; extra == 'docs'
- furo ; extra == 'docs'
- sphinx-argparse ; extra == 'docs'
- towncrier ; extra == 'docs'
- black ; extra == 'lint'
- flake8 ; extra == 'lint'
- isort ; extra == 'lint'
- mypy ; extra == 'lint'
- check-manifest ; extra == 'lint'
- asv ; extra == 'benchmark'
- cython ; extra == 'dev'
- greenlet ; python_full_version < '3.14' and extra == 'dev'
- pytest ; extra == 'dev'
- pytest-cov ; extra == 'dev'
- ipython ; extra == 'dev'
- setuptools ; extra == 'dev'
- pytest-textual-snapshot ; extra == 'dev'
- textual>=0.43,!=0.65.2,!=0.66 ; extra == 'dev'
- packaging ; extra == 'dev'
- black ; extra == 'dev'
- flake8 ; extra == 'dev'
- isort ; extra == 'dev'
- mypy ; extra == 'dev'
- check-manifest ; extra == 'dev'
- ipython ; extra == 'dev'
- bump2version ; extra == 'dev'
- sphinx ; extra == 'dev'
- furo ; extra == 'dev'
- sphinx-argparse ; extra == 'dev'
- towncrier ; extra == 'dev'
- asv ; extra == 'dev'
requires_python: '>=3.7.0'
- pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
name: mercantile
version: 1.2.1
@ -8866,6 +8921,34 @@ packages:
- pkg:pypi/terminado?source=hash-mapping
size: 22452
timestamp: 1710262728753
- pypi: https://files.pythonhosted.org/packages/53/b3/95ab646b0c908823d71e49ab8b5949ec9f33346cee3897d1af6be28a8d91/textual-6.6.0-py3-none-any.whl
name: textual
version: 6.6.0
sha256: 5a9484bd15ee8a6fd8ac4ed4849fb25ee56bed2cecc7b8a83c4cd7d5f19515e5
requires_dist:
- markdown-it-py[linkify]>=2.1.0
- mdit-py-plugins
- platformdirs>=3.6.0,<5
- pygments>=2.19.2,<3.0.0
- rich>=14.2.0
- tree-sitter>=0.25.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-bash>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-css>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-go>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-html>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-java>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-javascript>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-json>=0.24.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-markdown>=0.3.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-python>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-regex>=0.24.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-rust>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-sql>=0.3.11 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-toml>=0.6.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-xml>=0.7.0 ; python_full_version >= '3.10' and extra == 'syntax'
- tree-sitter-yaml>=0.6.0 ; python_full_version >= '3.10' and extra == 'syntax'
- typing-extensions>=4.4.0,<5.0.0
requires_python: '>=3.9,<4.0'
- conda: https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.6.0-pyhecae5ae_0.conda
sha256: 6016672e0e72c4cf23c0cf7b1986283bd86a9c17e8d319212d78d8e9ae42fdfd
md5: 9d64911b31d57ca443e9f1e36b04385f

View file

@ -57,7 +57,7 @@ dependencies = [
"xgboost>=3.1.1,<4",
"s3fs>=2025.10.0,<2026",
"xarray-spatial",
"cupy-xarray>=0.1.4,<0.2",
"cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2",
]
[project.scripts]

View file

@ -1,16 +1,27 @@
"""Aggregation helpers."""
import gc
import os
import sys
from collections import defaultdict
from collections.abc import Callable, Generator
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, field
from functools import cache
from typing import Literal
import cupy_xarray
import geopandas as gpd
import numpy as np
import odc.geo.geobox
import pandas as pd
import psutil
import shapely
import shapely.ops
import sklearn
import xarray as xr
import xdggs
import xvec
from rich import print
from rich.progress import track
from shapely.geometry import LineString, Polygon
from stopuhr import stopwatch
@ -19,34 +30,100 @@ from xdggs.healpix import HealpixInfo
from entropice import grids
@dataclass
@dataclass(frozen=True)
class _Aggregations:
# ! The ordering is super important for this class!
mean: bool = True
sum: bool = False
std: bool = False
min: bool = False
max: bool = False
quantiles: list[float] = field(default_factory=lambda: [])
median: bool = False # Will be mapped as quantile 0.5
quantiles: tuple[float] = field(default_factory=tuple)
def varnames(self, vars: list[str] | str) -> list[str]:
if isinstance(vars, str):
vars = [vars]
agg_vars = []
for var in vars:
def __post_init__(self):
assert isinstance(self.quantiles, tuple), "Quantiles must be a tuple, ortherwise the class is not hashable"
if self.median and 0.5 in self.quantiles:
raise ValueError("Median aggregation cannot be used together with quantile 0.5")
@cache
def __len__(self) -> int:
length = 0
if self.mean:
agg_vars.append(f"{var}_mean")
length += 1
if self.sum:
agg_vars.append(f"{var}_sum")
length += 1
if self.std:
agg_vars.append(f"{var}_std")
length += 1
if self.min:
agg_vars.append(f"{var}_min")
length += 1
if self.max:
agg_vars.append(f"{var}_max")
for q in self.quantiles:
length += 1
if self.median and 0.5 not in self.quantiles:
length += 1
length += len(self.quantiles)
return length
def aggnames(self) -> list[str]:
names = []
if self.mean:
names.append("mean")
if self.sum:
names.append("sum")
if self.std:
names.append("std")
if self.min:
names.append("min")
if self.max:
names.append("max")
if self.median:
names.append("median")
for q in sorted(self.quantiles): # ? Ordering is important here!
q_int = int(q * 100)
agg_vars.append(f"{var}_p{q_int}")
return agg_vars
names.append(f"p{q_int}")
return names
def _agg_cell_data_single(self, flattened_var: xr.DataArray) -> np.ndarray:
    """Reduce one stacked variable over its ``z`` dimension.

    Args:
        flattened_var: Data array whose spatial dimensions have been stacked
            into a single ``z`` dimension.

    Returns:
        A float32 array of shape ``(len(self), *non-z dims)``. Rows are filled
        in the order mean, sum, std, min, max, median, then the configured
        quantiles in ascending order; disabled aggregations take no row.

    """
    others_shape = tuple(flattened_var.sizes[dim] for dim in flattened_var.dims if dim != "z")
    cell_data = np.full((len(self), *others_shape), np.nan, dtype=np.float32)

    # ? Ordering is important here: it defines the row index of each aggregation.
    reducers = (
        (self.mean, flattened_var.mean),
        (self.sum, flattened_var.sum),
        (self.std, flattened_var.std),
        (self.min, flattened_var.min),
        (self.max, flattened_var.max),
    )
    j = 0
    for enabled, reduce in reducers:
        if enabled:
            cell_data[j, ...] = reduce(dim="z", skipna=True).to_numpy()
            j += 1

    if self.quantiles or self.median:
        quantiles_to_compute = sorted(self.quantiles)  # ? Ordering is important here!
        if self.median:
            # The median is stored in front of the other quantiles.
            quantiles_to_compute.insert(0, 0.5)
        # NOTE(review): relies on xarray putting the new quantile dimension first - confirm.
        cell_data[j:, ...] = flattened_var.quantile(
            q=quantiles_to_compute,
            dim="z",
            skipna=True,
        ).to_numpy()

    # The original never returned the filled array, so every caller received ``None``.
    return cell_data
def agg_cell_data(self, flattened: xr.Dataset | xr.DataArray) -> np.ndarray:
    """Aggregate stacked data over ``z`` for a data array or for every variable of a dataset.

    A data array is delegated directly to ``_agg_cell_data_single``. For a
    dataset the per-variable results are written into one float32 array of
    shape ``(n_variables, len(self), *non-z dims)``, pre-filled with NaN.
    """
    if isinstance(flattened, xr.DataArray):
        return self._agg_cell_data_single(flattened)

    trailing_shape = tuple(flattened.sizes[name] for name in flattened.dims if name != "z")
    n_vars = len(flattened.data_vars)
    stacked = np.full((n_vars, len(self), *trailing_shape), np.nan, dtype=np.float32)

    for position, var_name in enumerate(flattened.data_vars):
        stacked[position, ...] = self._agg_cell_data_single(flattened[var_name])

    # The cupy -> numpy transfer (``.get()``) is currently disabled.
    return stacked
def _crosses_antimeridian(geom: Polygon) -> bool:
@ -77,11 +154,11 @@ def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1
@stopwatch("Getting corrected geometries", log=False)
@stopwatch("Correcting geometries", log=False)
def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]:
geom, gbox, crs = inp
# cell.geometry is a shapely Polygon
if not _crosses_antimeridian(geom):
if crs != "EPSG:4326" or not _crosses_antimeridian(geom):
geoms = [geom]
# Split geometry in case it crossed antimeridian
else:
@ -92,153 +169,338 @@ def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> lis
return geoms
@stopwatch("Correcting geometries")
def get_corrected_geometries(grid_gdf: gpd.GeoDataFrame, gbox: odc.geo.geobox.GeoBox):
"""Get corrected geometries for antimeridian-crossing polygons.
Args:
grid_gdf (gpd.GeoDataFrame): Grid GeoDataFrame.
gbox (odc.geo.geobox.GeoBox): GeoBox for spatial reference.
Returns:
list[list[odc.geo.Geometry]]: List of corrected, georeferenced geometries.
"""
with ProcessPoolExecutor(max_workers=20) as executor:
cell_geometries = list(
executor.map(_get_corrected_geoms, [(row.geometry, gbox, grid_gdf.crs) for _, row in grid_gdf.iterrows()])
)
return cell_geometries
@stopwatch("Aggregating cell data", log=False)
def _agg_cell_data(flattened: xr.Dataset, aggregations: _Aggregations):
cell_data = {}
for var in flattened.data_vars:
if aggregations.mean:
cell_data[f"{var}_mean"] = flattened[var].mean(dim="z", skipna=True)
if aggregations.sum:
cell_data[f"{var}_sum"] = flattened[var].sum(dim="z", skipna=True)
if aggregations.std:
cell_data[f"{var}_std"] = flattened[var].std(dim="z", skipna=True)
if aggregations.min:
cell_data[f"{var}_min"] = flattened[var].min(dim="z", skipna=True)
if aggregations.max:
cell_data[f"{var}_max"] = flattened[var].max(dim="z", skipna=True)
if len(aggregations.quantiles) > 0:
quantile_values = flattened[var].quantile(
q=aggregations.quantiles,
dim="z",
skipna=True,
)
for q, qv in zip(aggregations.quantiles, quantile_values):
q_int = int(q * 100)
cell_data[f"{var}_p{q_int}"] = qv
# Transform to numpy arrays
for key in cell_data:
cell_data[key] = cell_data[key].cupy.as_numpy().values
return cell_data
@stopwatch("Extracting cell data", log=False)
def _extract_cell_data(ds: xr.Dataset, geom: odc.geo.Geometry, aggregations: _Aggregations):
spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"]
cropped: xr.Dataset = ds.odc.crop(geom).drop_vars("spatial_ref")
flattened = cropped.stack(z=spatdims)
if flattened.z.size > 3000:
flattened = flattened.cupy.as_cupy()
cell_data = _agg_cell_data(flattened, aggregations)
def _extract_cell_data(cropped: xr.Dataset | xr.DataArray, aggregations: _Aggregations):
spatdims = ["latitude", "longitude"] if "latitude" in cropped.dims and "longitude" in cropped.dims else ["y", "x"]
flattened = cropped.stack(z=spatdims) # noqa: PD013
# if flattened.z.size > 3000:
# flattened = flattened.cupy.as_cupy()
cell_data = aggregations.agg_cell_data(flattened)
return cell_data
# with np.errstate(divide="ignore", invalid="ignore"):
# cell_data = cropped.mean(spatdims)
# return {var: cell_data[var].values for var in cell_data.data_vars}
@stopwatch("Extracting split cell data", log=False)
def _extract_split_cell_data(ds: xr.Dataset, geoms: list[odc.geo.Geometry], aggregations: _Aggregations):
spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"]
cropped: list[xr.Dataset] = [ds.odc.crop(geom).drop_vars("spatial_ref") for ds, geom in zip(ds, geoms)]
flattened = xr.concat([c.stack(z=spatdims) for c in cropped], dim="z")
cell_data = _agg_cell_data(flattened, aggregations)
def _extract_split_cell_data(cropped_list: list[xr.Dataset | xr.DataArray], aggregations: _Aggregations):
spatdims = (
["latitude", "longitude"]
if "latitude" in cropped_list[0].dims and "longitude" in cropped_list[0].dims
else ["y", "x"]
)
flattened = xr.concat([c.stack(z=spatdims) for c in cropped_list], dim="z") # noqa: PD013
cell_data = aggregations.agg_cell_data(flattened)
return cell_data
# partial_counts = [part.notnull().sum(dim=spatdims) for part in parts]
# with np.errstate(divide="ignore", invalid="ignore"):
# partial_means = [part.sum(dim=spatdims) for part in parts]
# n = xr.concat(partial_counts, dim="part").sum("part")
# cell_data = xr.concat(partial_means, dim="part").sum("part") / n
# return {var: cell_data[var].values for var in cell_data.data_vars}
@stopwatch("Cropping (and reading?) cell data", log=False)
def _read_cell_data(unaligned: xr.Dataset | xr.DataArray, geom: odc.geo.Geometry) -> xr.Dataset:
    """Crop ``unaligned`` to ``geom`` with the mask applied and load the result.

    The ``spatial_ref`` variable is dropped before the data is computed.
    """
    masked = unaligned.odc.crop(geom, apply_mask=True)
    without_ref: xr.Dataset = masked.drop_vars("spatial_ref")
    return without_ref.compute()
@stopwatch("Cropping (and reading?) split cell data", log=False)
def _read_split_cell_data(unaligned: xr.Dataset | xr.DataArray, geoms: list[odc.geo.Geometry]) -> list[xr.Dataset]:
    """Crop ``unaligned`` to every geometry in ``geoms`` and load each part.

    All crops are set up first (mask applied, ``spatial_ref`` dropped) and
    are only computed afterwards, one by one, in the order of ``geoms``.
    """
    pending: list[xr.Dataset] = []
    for part in geoms:
        masked = unaligned.odc.crop(part, apply_mask=True)
        pending.append(masked.drop_vars("spatial_ref"))

    loaded = []
    for piece in pending:
        loaded.append(piece.compute())
    return loaded
@stopwatch("Processing cell", log=False)
def _process_geom(poly: Polygon, unaligned: xr.Dataset | xr.DataArray, aggregations: _Aggregations):
    """Correct, read and aggregate the raster data for one grid cell polygon.

    Raises:
        ValueError: If the geometry correction yields no geometry at all.

    """
    parts = _get_corrected_geoms((poly, unaligned.odc.geobox, unaligned.odc.crs))
    if not parts:
        raise ValueError("No geometries to process")

    if len(parts) == 1:
        # Single geometry: plain crop followed by aggregation.
        single = _read_cell_data(unaligned, parts[0])
        return _extract_cell_data(single, aggregations)

    # Several geometries: read every part and aggregate them together.
    pieces = _read_split_cell_data(unaligned, parts)
    return _extract_split_cell_data(pieces, aggregations)
def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]:
    """Yield ``n_partitions`` subsets of ``grid_gdf`` grouped by KMeans on the cell centroids.

    Rows are grouped by cluster label, so the concatenation of the yielded
    partitions is generally not in the row order of ``grid_gdf``.
    """
    # ! TODO: Adjust for antimeridian crossing, this is only needed if CRS is not 3413
    centers = grid_gdf.geometry.centroid
    features = pd.DataFrame({"x": centers.x, "y": centers.y})
    # Fixed random_state keeps the clustering reproducible between runs.
    clusterer = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42)
    assignment = clusterer.fit_predict(features)
    yield from (grid_gdf[assignment == label] for label in range(n_partitions))
class _MemoryProfiler:
def __init__(self):
self.process = psutil.Process()
self._logs = defaultdict(list)
def log_memory(self, message: str, log: bool = True):
mem = self.process.memory_info().rss / 1024**3
self._logs[message].append(mem)
if log:
print(f"[MemoryProfiler] {message}: {mem:.2f} GB")
def summary(self):
summary_lines = ["[MemoryProfiler] Memory usage summary:"]
for message, mems in self._logs.items():
mems_str = [f"{m:.2f}" for m in mems]
summary_lines.append(f" {message}: min={min(mems):.2f} GB, max={max(mems):.2f} GB")
summary_lines.append(f" Measurements: {', '.join(mems_str)}")
return "\n".join(summary_lines)
# Module-level memory profiler; stays ``None`` until ``_init_worker`` has run in this process.
memprof = None


def _init_worker():
    """Create a ``_MemoryProfiler`` and store it in the module-level ``memprof``."""
    global memprof
    memprof = _MemoryProfiler()
def _align_partition(
    grid_partition_gdf: gpd.GeoDataFrame,
    raster: xr.Dataset | Callable[[], xr.Dataset],
    aggregations: _Aggregations,
    pxbuffer: int,
):
    """Aggregate the raster into all grid cells of one partition.

    Args:
        grid_partition_gdf: The grid cells of this partition.
        raster: The raster dataset, or a zero-argument callable that opens it.
        aggregations: The aggregations to compute for every cell.
        pxbuffer: Number of raster pixels by which the partition extent is
            enlarged before the raster is cropped and loaded.

    Returns:
        A float32 array of shape ``(cells, variables, aggregations, *other dims)``.
        Rows of cells that produced no geometry or failed stay NaN; if cropping
        the raster to the partition extent fails, the whole array is NaN.

    """
    # Strategy for each cell:
    # 1. Correct the geometry to account for the antimeridian
    # 2. Crop the dataset and load the data into memory
    # 3. Aggregate the data
    # 4. Store the aggregated data in the correct location in the output arrays
    # Steps 1-3 are done in different Processes, since all of these steps are CPU bound
    # Even the reading part is CPU bound, since it involves applying masks and cropping the data
    # Also, the backend (icechunk or zarr) can do some computations, like combining chunks together
    # For the Arcticdem on Hex3 grid the avg. time per task is:
    # 1. ~0.01s
    # 2. ~1-3s
    # 3. ~10-20s
    # And for Hex6:
    # 1. unmeasurable
    # 2. ~0.1-0.3s
    # 3. ~0.05s

    # ``memprof`` is only set by ``_init_worker``; fall back to a local profiler otherwise.
    profiler = memprof if memprof is not None else _MemoryProfiler()
    profiler.log_memory("Before reading partial raster", log=False)

    if callable(raster) and not isinstance(raster, xr.Dataset):
        raster = raster()

    if hasattr(raster, "__dask_graph__"):
        graph = raster.__dask_graph__()
        # The graph is ``None`` when the raster is not backed by dask arrays.
        if graph is not None:
            graph_size = len(graph)
            graph_memory = sys.getsizeof(graph) / 1024**2
            print(f"Raster graph size: {graph_size} tasks, {graph_memory:.2f} MB")

    others_shape = tuple(
        raster.sizes[dim] for dim in raster.dims if dim not in ("y", "x", "latitude", "longitude")
    )
    data_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
    data = np.full(data_shape, np.nan, dtype=np.float32)

    partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs)
    resolution = raster.odc.geobox.resolution
    # A resolution component can be negative (NOTE(review): odc-geo appears to use a negative y
    # for north-up rasters - confirm); a negative buffer would shrink the extent instead of growing it.
    partial_extent = partial_extent.buffered(
        abs(resolution.x) * pxbuffer,
        abs(resolution.y) * pxbuffer,
    )  # buffer by pxbuffer pixels

    with stopwatch("Cropping raster to partition extent", log=False):
        try:
            partial_raster = raster.odc.crop(partial_extent, apply_mask=False).compute()
        except Exception as e:
            print(f"Error cropping raster to partition extent: {e}")
            return data

    print(f"{os.getpid()}: Partial raster size: {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)")
    profiler.log_memory("After reading partial raster", log=True)

    future = None  # keeps the cleanup below valid when no task is submitted
    with ProcessPoolExecutor(
        max_workers=60,
    ) as executor:
        futures = {}
        for i, (_, row) in enumerate(grid_partition_gdf.iterrows()):
            # ? Splitting the geometries already here and passing only the cropped data to the worker to
            # reduce pickling overhead
            geoms = _get_corrected_geoms((row.geometry, partial_raster.odc.geobox, partial_raster.odc.crs))
            if len(geoms) == 0:
                continue
            elif len(geoms) == 1:
                cropped = [_read_cell_data(partial_raster, geoms[0])]
            else:
                cropped = _read_split_cell_data(partial_raster, geoms)
            futures[executor.submit(_extract_split_cell_data, cropped, aggregations)] = i
            # Clean up immediately after submitting
            del geoms, cropped

        for future in as_completed(futures):
            # Drop the bookkeeping entry right away, also for failed cells, so that
            # neither the future nor its result is kept alive longer than needed.
            i = futures.pop(future)
            try:
                cell_data = future.result()
            except (SystemError, SystemExit, KeyboardInterrupt):
                raise
            except Exception as e:
                print(f"Error processing cell {i}: {e}")
                continue
            data[i, ...] = cell_data
            del cell_data

    partial_raster.close()
    raster.close()
    del partial_raster, raster, future, futures
    gc.collect()
    profiler.log_memory("After cleaning", log=True)

    print("Finished processing partition")
    print("### Stopwatch summary ###\n")
    print(stopwatch.summary())
    print("### Memory summary ###\n")
    print(profiler.summary())
    print("#########################")

    return data
@stopwatch("Aligning data with grid")
def _align_data(
grid_gdf: gpd.GeoDataFrame,
unaligned: xr.Dataset,
raster: xr.Dataset | Callable[[], xr.Dataset],
aggregations: _Aggregations,
) -> dict[str, np.ndarray]:
# Persist the dataset, as all the operations here MUST NOT be lazy
unaligned = unaligned.load()
n_partitions: int,
max_workers: int,
pxbuffer: int,
):
partial_data = {}
cell_geometries = get_corrected_geometries(grid_gdf, unaligned.odc.geobox)
other_dims_shape = tuple(
[unaligned.sizes[dim] for dim in unaligned.dims if dim not in ["y", "x", "latitude", "longitude"]]
)
data_shape = (len(cell_geometries), *other_dims_shape)
data = {var: np.full(data_shape, np.nan, dtype=np.float32) for var in aggregations.varnames(unaligned.data_vars)}
with ProcessPoolExecutor(max_workers=10) as executor:
with ProcessPoolExecutor(
max_workers=max_workers,
initializer=_init_worker,
# initializer=_init_raster_global,
# initargs=(raster,),
) as executor:
futures = {}
for i, geoms in track(
enumerate(cell_geometries), total=len(cell_geometries), description="Submitting cell extraction tasks..."
):
if len(geoms) == 0:
continue
elif len(geoms) == 1:
geom = geoms[0]
# Reduce the amount of data needed to be sent to the worker
# Since we dont mask the data, only isel operations are done here
# Thus, to properly extract the subset, another masked crop needs to be done in the worker
unaligned_subset = unaligned.odc.crop(geom, apply_mask=False)
fut = executor.submit(_extract_cell_data, unaligned_subset, geom, aggregations)
else:
# Same as above but for multiple parts
unaligned_subsets = [unaligned.odc.crop(geom, apply_mask=False) for geom in geoms]
fut = executor.submit(_extract_split_cell_data, unaligned_subsets, geoms, aggregations)
futures[fut] = i
for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)):
futures[
executor.submit(
_align_partition,
grid_partition,
raster.copy() if isinstance(raster, xr.Dataset) else raster,
aggregations,
pxbuffer,
)
] = i
print("Submitted all partitions, waiting for results...")
for future in track(
as_completed(futures), total=len(futures), description="Spatially aggregating ERA5 data..."
as_completed(futures),
total=len(futures),
description="Processing grid partitions...",
):
try:
i = futures[future]
cell_data = future.result()
for var in unaligned.data_vars:
data[var][i, :] = cell_data[var]
print(f"Processed partition {i + 1}/{len(futures)}")
part_data = future.result()
partial_data[i] = part_data
except Exception as e:
print(f"Error processing partition {i}: {e}")
raise e
data = np.concatenate([partial_data[i] for i in range(len(partial_data))], axis=0)
return data
def aggregate_raster_into_grid(
raster: xr.Dataset,
raster: xr.Dataset | Callable[[], xr.Dataset],
grid_gdf: gpd.GeoDataFrame,
aggregations: _Aggregations,
grid: Literal["hex", "healpix"],
level: int,
n_partitions: int = 20,
max_workers: int = 5,
pxbuffer: int = 15,
):
aligned = _align_data(grid_gdf, raster, aggregations)
"""Aggregate raster data into grid cells.
Args:
raster (xr.Dataset | Callable[[], xr.Dataset]): Raster data or a function that returns it.
grid_gdf (gpd.GeoDataFrame): The grid to aggregate into.
aggregations (_Aggregations): The aggregations to perform.
grid (Literal["hex", "healpix"]): The type of grid to use.
level (int): The level of the grid.
n_partitions (int, optional): Number of partitions to divide the grid into. Defaults to 20.
max_workers (int, optional): Maximum number of worker processes. Defaults to 5.
pxbuffer (int, optional): Pixel buffer around each grid cell. Defaults to 15.
Returns:
xr.Dataset: Aggregated data aligned with the grid.
"""
aligned = _align_data(
grid_gdf,
raster,
aggregations,
n_partitions=n_partitions,
max_workers=max_workers,
pxbuffer=pxbuffer,
)
# Dims of aligned: (cell_ids, variables, aggregations, other dims...)
if callable(raster) and not isinstance(raster, xr.Dataset):
raster = raster()
cell_ids = grids.get_cell_ids(grid, level)
dims = ["cell_ids"]
coords = {"cell_ids": cell_ids}
for dim in raster.dims:
if dim not in ["y", "x", "latitude", "longitude"]:
dims = ["cell_ids", "variables", "aggregations"]
coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
dims.append(dim)
coords[dim] = raster.coords[dim]
data_vars = {var: (dims, values) for var, values in aligned.items()}
ongrid = xr.Dataset(
data_vars,
coords=coords,
)
ongrid = xr.DataArray(aligned, dims=dims, coords=coords).to_dataset("variables")
gridinfo = {
"grid_name": "h3" if grid == "hex" else grid,
"level": level,
@ -247,8 +509,7 @@ def aggregate_raster_into_grid(
gridinfo["indexing_scheme"] = "nested"
ongrid.cell_ids.attrs = gridinfo
for var in raster.data_vars:
for v in aggregations.varnames(var):
ongrid[v].attrs = raster[var].attrs
ongrid[var].attrs = raster[var].attrs
ongrid = xdggs.decode(ongrid)
return ongrid

View file

@ -77,7 +77,7 @@ def download(grid: Literal["hex", "healpix"], level: int):
return feature.set(mean_dict)
# Process grid in batches of 100
batch_size = 100
batch_size = 50
all_results = []
n_batches = len(grid_gdf) // batch_size
for batch_num, batch_grid in track(

View file

@ -1,4 +1,5 @@
import datetime
import multiprocessing as mp
from dataclasses import dataclass
from math import ceil
from typing import Literal
@ -9,6 +10,7 @@ import cupyx.scipy.signal
import cyclopts
import dask.array
import dask.distributed as dd
import icechunk
import icechunk.xarray
import numpy as np
import smart_geocubes
@ -301,18 +303,23 @@ def enrich():
print("Enrichment complete.")
@cli.command()
def aggregate(grid: Literal["hex", "healpix"], level: int):
with stopwatch(f"Loading {grid} grid at level {level}"):
grid_gdf = grids.open(grid, level)
# ? Mask out water, since we don't want to aggregate over oceans
grid_gdf = watermask.clip_grid(grid_gdf)
def _open_adem():
    """Open the ArcticDEM datacube through ``smart_geocubes.ArcticDEM32m`` and sanity-check it.

    The opened object must have exactly the dimensions ``x`` and ``y`` and the
    CRS ``EPSG:3413``; both are checked with assertions.
    """
    accessor = smart_geocubes.ArcticDEM32m(get_arcticdem_stores())
    adem = accessor.open_xarray()
    found_dims = set(adem.dims)
    assert {"x", "y"} == found_dims
    assert adem.odc.crs == "EPSG:3413"
    return adem
@cli.command()
def aggregate(grid: Literal["hex", "healpix"], level: int, max_workers: int = 20):
mp.set_start_method("forkserver", force=True)
with stopwatch(f"Loading {grid} grid at level {level}"):
grid_gdf = grids.open(grid, level)
# ? Mask out water, since we don't want to aggregate over oceans
grid_gdf = watermask.clip_grid(grid_gdf)
aggregations = _Aggregations(
mean=True,
@ -320,13 +327,29 @@ def aggregate(grid: Literal["hex", "healpix"], level: int):
std=True,
min=True,
max=True,
quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99],
median=True,
quantiles=(0.01, 0.05, 0.25, 0.75, 0.95, 0.99),
)
aggregated = aggregate_raster_into_grid(
_open_adem,
grid_gdf,
aggregations,
grid,
level,
max_workers=max_workers,
n_partitions=200, # Results in 10GB large patches
pxbuffer=30,
)
aggregated = aggregate_raster_into_grid(adem, grid_gdf, aggregations, grid, level)
store = get_arcticdem_stores(grid, level)
with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"):
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
print("Aggregation complete.")
print("### Finished ArcticDEM processing ###")
stopwatch.summary()
if __name__ == "__main__":
cli()