From 341fa7b836a3fd53adb0fedb3805e352f75ab749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Mon, 24 Nov 2025 22:40:17 +0100 Subject: [PATCH] leaking grid aggregations --- pixi.lock | 85 +++++- pyproject.toml | 2 +- src/entropice/aggregators.py | 545 ++++++++++++++++++++++++++--------- src/entropice/alphaearth.py | 2 +- src/entropice/arcticdem.py | 41 ++- 5 files changed, 521 insertions(+), 154 deletions(-) diff --git a/pixi.lock b/pixi.lock index 4021421..aefb099 100644 --- a/pixi.lock +++ b/pixi.lock @@ -543,6 +543,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/ab/b5/36c712098e6191d1b4e349304ef73a8d06aed77e56ceaac8c0a306c7bda1/jupyterlab_widgets-3.0.16-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c4/bd/ba44a47578ea48ee28b54543c1de8c529eedad8317516a2a753e6d9c77c5/lonboard-0.13.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/52/40/617b15e62d5de1718e81ee436a1f19d4d40274ead97ac0eda188baebb986/memray-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/93/cf/be4e93afbfa0def2cd6fac9302071db0bd6d0617999ecbf53f92b9398de3/multiurl-0.3.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl @@ -594,6 +595,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/39/60/868371b6482ccd9ef423c6f62650066cf8271fdb2ee84f192695ad6b7a96/streamlit-1.51.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/53/b3/95ab646b0c908823d71e49ab8b5949ec9f33346cee3897d1af6be28a8d91/textual-6.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8d/c0/fdf9d3ee103ce66a55f0532835ad5e154226c5222423c6636ba049dc42fc/traittypes-0.2.3-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/06/af/413f6b172f9d4c4943b980a9fd96bb4d915680ce8f79c07de6f697b45c8b/ultraplot-1.65.1-py3-none-any.whl @@ -2808,7 +2810,7 @@ packages: - pypi: ./ name: entropice version: 0.1.0 - sha256: d95c691c76206bf54e207fe02b100a247a3847f37135d1cdf6ee18165770ea46 + sha256: 788df3ea7773f54fce8274d57d51af8edbab106c5b2082f8efca6639aa3eece9 requires_dist: - aiohttp>=3.12.11 - bokeh>=3.7.3 @@ -2862,6 +2864,7 @@ packages: - s3fs>=2025.10.0,<2026 - xarray-spatial @ git+https://github.com/relativityhd/xarray-spatial - cupy-xarray>=0.1.4,<0.2 + - memray>=1.19.1,<2 requires_python: '>=3.13,<3.14' editable: true - pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7 @@ -6344,6 +6347,58 @@ packages: - pkg:pypi/mdurl?source=hash-mapping size: 14465 timestamp: 1733255681319 +- pypi: https://files.pythonhosted.org/packages/52/40/617b15e62d5de1718e81ee436a1f19d4d40274ead97ac0eda188baebb986/memray-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl + name: memray + version: 1.19.1 + sha256: 9770b6046a8b7d7bb982c2cefb324a60a0ee7cd7c35c58c0df058355a9a54676 + requires_dist: + - jinja2>=2.9 + - typing-extensions ; python_full_version < '3.8' + - rich>=11.2.0 + - textual>=0.41.0 + - cython ; extra == 'test' + - greenlet ; python_full_version < '3.14' and extra == 'test' + - pytest ; extra == 'test' + - pytest-cov ; extra == 'test' + - ipython ; extra == 'test' + - setuptools ; extra == 'test' + - pytest-textual-snapshot ; extra == 'test' + - textual>=0.43,!=0.65.2,!=0.66 ; extra == 'test' + - packaging ; extra == 'test' + - ipython ; extra == 'docs' + - bump2version ; extra == 'docs' + - sphinx ; extra == 'docs' + - furo ; extra == 'docs' + - sphinx-argparse ; extra == 'docs' + - towncrier ; extra == 'docs' + - black ; extra == 'lint' + - flake8 ; extra == 'lint' + - isort ; extra == 'lint' + - mypy ; extra == 'lint' + - check-manifest ; extra == 'lint' + - asv ; extra == 'benchmark' + - cython ; extra == 'dev' + - greenlet ; python_full_version < '3.14' and extra == 'dev' + - pytest ; extra == 'dev' + - pytest-cov ; extra == 'dev' + - ipython ; extra == 'dev' + - setuptools ; extra == 'dev' + - pytest-textual-snapshot ; extra == 'dev' + - textual>=0.43,!=0.65.2,!=0.66 ; extra == 'dev' + - packaging ; extra == 'dev' + - black ; extra == 'dev' + - flake8 ; extra == 'dev' + - isort ; extra == 'dev' + - mypy ; extra == 'dev' + - check-manifest ; extra == 'dev' + - ipython ; extra == 'dev' + - bump2version ; extra == 'dev' + - sphinx ; extra == 'dev' + - furo ; extra == 'dev' + - sphinx-argparse ; extra == 'dev' + - towncrier ; extra == 'dev' + - asv ; extra == 'dev' + requires_python: '>=3.7.0' - pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl name: mercantile version: 1.2.1 @@ -8866,6 +8921,34 @@ packages: - pkg:pypi/terminado?source=hash-mapping size: 22452 timestamp: 1710262728753 +- pypi: https://files.pythonhosted.org/packages/53/b3/95ab646b0c908823d71e49ab8b5949ec9f33346cee3897d1af6be28a8d91/textual-6.6.0-py3-none-any.whl + name: textual + version: 6.6.0 + sha256: 5a9484bd15ee8a6fd8ac4ed4849fb25ee56bed2cecc7b8a83c4cd7d5f19515e5 + requires_dist: + - markdown-it-py[linkify]>=2.1.0 + - mdit-py-plugins + - platformdirs>=3.6.0,<5 + - pygments>=2.19.2,<3.0.0 + - rich>=14.2.0 + - tree-sitter>=0.25.0 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-bash>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-css>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-go>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-html>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-java>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-javascript>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-json>=0.24.0 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-markdown>=0.3.0 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-python>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-regex>=0.24.0 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-rust>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-sql>=0.3.11 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-toml>=0.6.0 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-xml>=0.7.0 ; python_full_version >= '3.10' and extra == 'syntax' + - tree-sitter-yaml>=0.6.0 ; python_full_version >= '3.10' and extra == 'syntax' + - typing-extensions>=4.4.0,<5.0.0 + requires_python: '>=3.9,<4.0' - conda: https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.6.0-pyhecae5ae_0.conda sha256: 6016672e0e72c4cf23c0cf7b1986283bd86a9c17e8d319212d78d8e9ae42fdfd md5: 9d64911b31d57ca443e9f1e36b04385f diff --git a/pyproject.toml b/pyproject.toml index 2484a90..d72c0d0 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ dependencies = [ "xgboost>=3.1.1,<4", "s3fs>=2025.10.0,<2026", "xarray-spatial", - "cupy-xarray>=0.1.4,<0.2", + "cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", ] [project.scripts] diff --git a/src/entropice/aggregators.py b/src/entropice/aggregators.py index 9534880..e9cac8b 100644 --- a/src/entropice/aggregators.py +++ b/src/entropice/aggregators.py @@ -1,16 +1,27 @@ +"""Aggregation helpers.""" + +import gc +import os +import sys +from collections import defaultdict +from collections.abc import Callable, Generator from concurrent.futures import ProcessPoolExecutor, as_completed from dataclasses import dataclass, field +from functools import cache from typing import Literal -import cupy_xarray import geopandas as gpd import numpy as np import odc.geo.geobox +import pandas as pd +import psutil import shapely import shapely.ops +import sklearn import xarray as xr import xdggs import xvec +from rich import print from rich.progress import track from shapely.geometry import LineString, Polygon from stopuhr import stopwatch @@ -19,34 +30,100 @@ from xdggs.healpix import HealpixInfo from entropice import grids -@dataclass +@dataclass(frozen=True) class _Aggregations: + # ! The ordering is super important for this class! mean: bool = True sum: bool = False std: bool = False min: bool = False max: bool = False - quantiles: list[float] = field(default_factory=lambda: []) + median: bool = False # Will be mapped as quantile 0.5 + quantiles: tuple[float] = field(default_factory=tuple) - def varnames(self, vars: list[str] | str) -> list[str]: - if isinstance(vars, str): - vars = [vars] - agg_vars = [] - for var in vars: - if self.mean: - agg_vars.append(f"{var}_mean") - if self.sum: - agg_vars.append(f"{var}_sum") - if self.std: - agg_vars.append(f"{var}_std") - if self.min: - agg_vars.append(f"{var}_min") - if self.max: - agg_vars.append(f"{var}_max") - for q in self.quantiles: - q_int = int(q * 100) - agg_vars.append(f"{var}_p{q_int}") - return agg_vars + def __post_init__(self): + assert isinstance(self.quantiles, tuple), "Quantiles must be a tuple, ortherwise the class is not hashable" + if self.median and 0.5 in self.quantiles: + raise ValueError("Median aggregation cannot be used together with quantile 0.5") + + @cache + def __len__(self) -> int: + length = 0 + if self.mean: + length += 1 + if self.sum: + length += 1 + if self.std: + length += 1 + if self.min: + length += 1 + if self.max: + length += 1 + if self.median and 0.5 not in self.quantiles: + length += 1 + length += len(self.quantiles) + return length + + def aggnames(self) -> list[str]: + names = [] + if self.mean: + names.append("mean") + if self.sum: + names.append("sum") + if self.std: + names.append("std") + if self.min: + names.append("min") + if self.max: + names.append("max") + if self.median: + names.append("median") + for q in sorted(self.quantiles): # ? Ordering is important here! + q_int = int(q * 100) + names.append(f"p{q_int}") + return names + + def _agg_cell_data_single(self, flattened_var: xr.DataArray) -> np.ndarray: + others_shape = tuple([flattened_var.sizes[dim] for dim in flattened_var.dims if dim != "z"]) + cell_data = np.full((len(self), *others_shape), np.nan, dtype=np.float32) + j = 0 + if self.mean: + cell_data[j, ...] = flattened_var.mean(dim="z", skipna=True).to_numpy() + j += 1 + if self.sum: + cell_data[j, ...] = flattened_var.sum(dim="z", skipna=True).to_numpy() + j += 1 + if self.std: + cell_data[j, ...] = flattened_var.std(dim="z", skipna=True).to_numpy() + j += 1 + if self.min: + cell_data[j, ...] = flattened_var.min(dim="z", skipna=True).to_numpy() + j += 1 + if self.max: + cell_data[j, ...] = flattened_var.max(dim="z", skipna=True).to_numpy() + j += 1 + if len(self.quantiles) > 0 or self.median: + quantiles_to_compute = sorted(self.quantiles) # ? Ordering is important here! + if self.median: + quantiles_to_compute.insert(0, 0.5) + cell_data[j:, ...] = flattened_var.quantile( + q=quantiles_to_compute, + dim="z", + skipna=True, + ).to_numpy() + + def agg_cell_data(self, flattened: xr.Dataset | xr.DataArray) -> np.ndarray: + if isinstance(flattened, xr.DataArray): + return self._agg_cell_data_single(flattened) + + others_shape = tuple([flattened.sizes[dim] for dim in flattened.dims if dim != "z"]) + cell_data = np.full((len(flattened.data_vars), len(self), *others_shape), np.nan, dtype=np.float32) + for i, var in enumerate(flattened.data_vars): + cell_data[i, ...] = self._agg_cell_data_single(flattened[var]) + # Transform to numpy arrays + # if flattened.cupy.is_cupy: + # cell_data = cell_data.get() + return cell_data def _crosses_antimeridian(geom: Polygon) -> bool: @@ -77,11 +154,11 @@ def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool: return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1 -@stopwatch("Getting corrected geometries", log=False) +@stopwatch("Correcting geometries", log=False) def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]: geom, gbox, crs = inp # cell.geometry is a shapely Polygon - if not _crosses_antimeridian(geom): + if crs != "EPSG:4326" or not _crosses_antimeridian(geom): geoms = [geom] # Split geometry in case it crossed antimeridian else: @@ -92,153 +169,338 @@ def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> lis return geoms -@stopwatch("Correcting geometries") -def get_corrected_geometries(grid_gdf: gpd.GeoDataFrame, gbox: odc.geo.geobox.GeoBox): - """Get corrected geometries for antimeridian-crossing polygons. - - Args: - grid_gdf (gpd.GeoDataFrame): Grid GeoDataFrame. - gbox (odc.geo.geobox.GeoBox): GeoBox for spatial reference. - - Returns: - list[list[odc.geo.Geometry]]: List of corrected, georeferenced geometries. - - """ - with ProcessPoolExecutor(max_workers=20) as executor: - cell_geometries = list( - executor.map(_get_corrected_geoms, [(row.geometry, gbox, grid_gdf.crs) for _, row in grid_gdf.iterrows()]) - ) - return cell_geometries - - -@stopwatch("Aggregating cell data", log=False) -def _agg_cell_data(flattened: xr.Dataset, aggregations: _Aggregations): - cell_data = {} - for var in flattened.data_vars: - if aggregations.mean: - cell_data[f"{var}_mean"] = flattened[var].mean(dim="z", skipna=True) - if aggregations.sum: - cell_data[f"{var}_sum"] = flattened[var].sum(dim="z", skipna=True) - if aggregations.std: - cell_data[f"{var}_std"] = flattened[var].std(dim="z", skipna=True) - if aggregations.min: - cell_data[f"{var}_min"] = flattened[var].min(dim="z", skipna=True) - if aggregations.max: - cell_data[f"{var}_max"] = flattened[var].max(dim="z", skipna=True) - if len(aggregations.quantiles) > 0: - quantile_values = flattened[var].quantile( - q=aggregations.quantiles, - dim="z", - skipna=True, - ) - for q, qv in zip(aggregations.quantiles, quantile_values): - q_int = int(q * 100) - cell_data[f"{var}_p{q_int}"] = qv - # Transform to numpy arrays - for key in cell_data: - cell_data[key] = cell_data[key].cupy.as_numpy().values - return cell_data - - @stopwatch("Extracting cell data", log=False) -def _extract_cell_data(ds: xr.Dataset, geom: odc.geo.Geometry, aggregations: _Aggregations): - spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"] - cropped: xr.Dataset = ds.odc.crop(geom).drop_vars("spatial_ref") - flattened = cropped.stack(z=spatdims) - if flattened.z.size > 3000: - flattened = flattened.cupy.as_cupy() - cell_data = _agg_cell_data(flattened, aggregations) +def _extract_cell_data(cropped: xr.Dataset | xr.DataArray, aggregations: _Aggregations): + spatdims = ["latitude", "longitude"] if "latitude" in cropped.dims and "longitude" in cropped.dims else ["y", "x"] + flattened = cropped.stack(z=spatdims) # noqa: PD013 + # if flattened.z.size > 3000: + # flattened = flattened.cupy.as_cupy() + cell_data = aggregations.agg_cell_data(flattened) return cell_data - # with np.errstate(divide="ignore", invalid="ignore"): - # cell_data = cropped.mean(spatdims) - # return {var: cell_data[var].values for var in cell_data.data_vars} - @stopwatch("Extracting split cell data", log=False) -def _extract_split_cell_data(ds: xr.Dataset, geoms: list[odc.geo.Geometry], aggregations: _Aggregations): - spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"] - cropped: list[xr.Dataset] = [ds.odc.crop(geom).drop_vars("spatial_ref") for ds, geom in zip(ds, geoms)] - flattened = xr.concat([c.stack(z=spatdims) for c in cropped], dim="z") - cell_data = _agg_cell_data(flattened, aggregations) +def _extract_split_cell_data(cropped_list: list[xr.Dataset | xr.DataArray], aggregations: _Aggregations): + spatdims = ( + ["latitude", "longitude"] + if "latitude" in cropped_list[0].dims and "longitude" in cropped_list[0].dims + else ["y", "x"] + ) + flattened = xr.concat([c.stack(z=spatdims) for c in cropped_list], dim="z") # noqa: PD013 + cell_data = aggregations.agg_cell_data(flattened) return cell_data - # partial_counts = [part.notnull().sum(dim=spatdims) for part in parts] - # with np.errstate(divide="ignore", invalid="ignore"): - # partial_means = [part.sum(dim=spatdims) for part in parts] - # n = xr.concat(partial_counts, dim="part").sum("part") - # cell_data = xr.concat(partial_means, dim="part").sum("part") / n - # return {var: cell_data[var].values for var in cell_data.data_vars} + +@stopwatch("Cropping (and reading?) cell data", log=False) +def _read_cell_data(unaligned: xr.Dataset | xr.DataArray, geom: odc.geo.Geometry) -> xr.Dataset: + cropped: xr.Dataset = unaligned.odc.crop(geom, apply_mask=True).drop_vars("spatial_ref") + return cropped.compute() + + +@stopwatch("Cropping (and reading?) split cell data", log=False) +def _read_split_cell_data(unaligned: xr.Dataset | xr.DataArray, geoms: list[odc.geo.Geometry]) -> list[xr.Dataset]: + cropped: list[xr.Dataset] = [unaligned.odc.crop(geom, apply_mask=True).drop_vars("spatial_ref") for geom in geoms] + return [c.compute() for c in cropped] + + +@stopwatch("Processing cell", log=False) +def _process_geom(poly: Polygon, unaligned: xr.Dataset | xr.DataArray, aggregations: _Aggregations): + geoms = _get_corrected_geoms((poly, unaligned.odc.geobox, unaligned.odc.crs)) + if len(geoms) == 0: + raise ValueError("No geometries to process") + elif len(geoms) == 1: + cropped = _read_cell_data(unaligned, geoms[0]) + cell_data = _extract_cell_data(cropped, aggregations) + else: + cropped = _read_split_cell_data(unaligned, geoms) + cell_data = _extract_split_cell_data(cropped, aggregations) + return cell_data + + +def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]: + # ! TODO: Adjust for antimeridian crossing, this is only needed if CRS is not 3413 + # Simple partitioning by splitting the GeoDataFrame into n_partitions parts + centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y}) + labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids) + for i in range(n_partitions): + partition = grid_gdf[labels == i] + yield partition + + +class _MemoryProfiler: + def __init__(self): + self.process = psutil.Process() + self._logs = defaultdict(list) + + def log_memory(self, message: str, log: bool = True): + mem = self.process.memory_info().rss / 1024**3 + self._logs[message].append(mem) + if log: + print(f"[MemoryProfiler] {message}: {mem:.2f} GB") + + def summary(self): + summary_lines = ["[MemoryProfiler] Memory usage summary:"] + for message, mems in self._logs.items(): + mems_str = [f"{m:.2f}" for m in mems] + summary_lines.append(f" {message}: min={min(mems):.2f} GB, max={max(mems):.2f} GB") + summary_lines.append(f" Measurements: {', '.join(mems_str)}") + return "\n".join(summary_lines) + + +memprof = None + + +def _init_worker(): + global memprof + memprof = _MemoryProfiler() + + +def _align_partition( + grid_partition_gdf: gpd.GeoDataFrame, + raster: xr.Dataset | Callable[[], xr.Dataset], + aggregations: _Aggregations, + pxbuffer: int, +): + # Strategy for each cell: + # 1. Correct the geometry to account for the antimeridian + # 2. Cop the dataset and load the data into memory + # 3. Aggregate the data + # 4. Store the aggregated data in the correct location in the output arrays + + # Steps 1-3 are done in different Processes, since all of these steps are CPU bound + # Even the reading part is CPU bound, since it involves applying masks and cropping the data + # Also, the backend (icechunk or zarr) can do some computations, like combining chunks together + + # For the Arcticdem on Hex3 grid the avg. time per task is: + # 1. ~0.01s + # 2. ~1-3s + # 3. ~10-20s + # And for Hex6: + # 1. unmeasurable + # 2. ~0.1-0.3s + # 3. ~0.05s + + memprof.log_memory("Before reading partial raster", log=False) + + if callable(raster) and not isinstance(raster, xr.Dataset): + raster = raster() + + if hasattr(raster, "__dask_graph__"): + graph_size = len(raster.__dask_graph__()) + graph_memory = sys.getsizeof(raster.__dask_graph__()) / 1024**2 + print(f"Raster graph size: {graph_size} tasks, {graph_memory:.2f} MB") + + others_shape = tuple([raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]]) + data_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape) + data = np.full(data_shape, np.nan, dtype=np.float32) + + partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs) + partial_extent = partial_extent.buffered( + raster.odc.geobox.resolution.x * pxbuffer, + raster.odc.geobox.resolution.y * pxbuffer, + ) # buffer by pxbuffer pixels + with stopwatch("Cropping raster to partition extent", log=False): + try: + partial_raster = raster.odc.crop(partial_extent, apply_mask=False).compute() + except Exception as e: + print(f"Error cropping raster to partition extent: {e}") + return data + + print(f"{os.getpid()}: Partial raster size: {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)") + memprof.log_memory("After reading partial raster", log=True) + + with ProcessPoolExecutor( + max_workers=60, + ) as executor: + futures = {} + for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()): + # ? Splitting the geometries already here and passing only the cropped data to the worker to + # reduce pickling overhead + geoms = _get_corrected_geoms((row.geometry, partial_raster.odc.geobox, partial_raster.odc.crs)) + if len(geoms) == 0: + continue + elif len(geoms) == 1: + cropped = [_read_cell_data(partial_raster, geoms[0])] + else: + cropped = _read_split_cell_data(partial_raster, geoms) + futures[executor.submit(_extract_split_cell_data, cropped, aggregations)] = i + + # Clean up immediately after submitting + del geoms, cropped # Added this line + + for future in as_completed(futures): + i = futures[future] + try: + cell_data = future.result() + except (SystemError, SystemExit, KeyboardInterrupt) as e: + raise e + except Exception as e: + print(f"Error processing cell {i}: {e}") + continue + data[i, ...] = cell_data + del cell_data, futures[future] + + # for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()): + # try: + # cell_data = _process_geom(row.geometry, partial_raster, aggregations) + # except (SystemError, SystemExit, KeyboardInterrupt) as e: + # raise e + # except Exception as e: + # print(f"Error processing cell {row['cell_id']}: {e}") + # continue + # data[i, ...] = cell_data + + partial_raster.close() + raster.close() + del partial_raster, raster, future, futures + gc.collect() + memprof.log_memory("After cleaning", log=True) + + print("Finished processing partition") + print("### Stopwatch summary ###\n") + print(stopwatch.summary()) + print("### Memory summary ###\n") + print(memprof.summary()) + print("#########################") + # with ProcessPoolExecutor( + # max_workers=max_workers, + # initializer=_init_partial_raster_global, + # initargs=(raster, grid_partition_gdf, pxbuffer), + # ) as executor: + # futures = { + # executor.submit(_process_geom_global, row.geometry, aggregations): i + # for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()) + # } + + # others_shape = tuple( + # [raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]] + # ) + # data_shape = (len(futures), len(raster.data_vars), len(aggregations), *others_shape) + # print(f"Allocating partial data array of shape {data_shape}") + # data = np.full(data_shape, np.nan, dtype=np.float32) + + # lap = time.time() + # for future in track( + # as_completed(futures), + # total=len(futures), + # description="Spatially aggregating ERA5 data...", + # ): + # print(f"time since last cell: {time.time() - lap:.2f}s") + # i = futures[future] + # try: + # cell_data = future.result() + # data[i, ...] = cell_data + # except Exception as e: + # print(f"Error processing cell {i}: {e}") + # continue + # finally: + # # Try to free memory + # del futures[future], future + # if "cell_data" in locals(): + # del cell_data + # gc.collect() + # lap = time.time() + return data @stopwatch("Aligning data with grid") def _align_data( grid_gdf: gpd.GeoDataFrame, - unaligned: xr.Dataset, + raster: xr.Dataset | Callable[[], xr.Dataset], aggregations: _Aggregations, -) -> dict[str, np.ndarray]: - # Persist the dataset, as all the operations here MUST NOT be lazy - unaligned = unaligned.load() + n_partitions: int, + max_workers: int, + pxbuffer: int, +): + partial_data = {} - cell_geometries = get_corrected_geometries(grid_gdf, unaligned.odc.geobox) - other_dims_shape = tuple( - [unaligned.sizes[dim] for dim in unaligned.dims if dim not in ["y", "x", "latitude", "longitude"]] - ) - data_shape = (len(cell_geometries), *other_dims_shape) - data = {var: np.full(data_shape, np.nan, dtype=np.float32) for var in aggregations.varnames(unaligned.data_vars)} - with ProcessPoolExecutor(max_workers=10) as executor: + with ProcessPoolExecutor( + max_workers=max_workers, + initializer=_init_worker, + # initializer=_init_raster_global, + # initargs=(raster,), + ) as executor: futures = {} - for i, geoms in track( - enumerate(cell_geometries), total=len(cell_geometries), description="Submitting cell extraction tasks..." - ): - if len(geoms) == 0: - continue - elif len(geoms) == 1: - geom = geoms[0] - # Reduce the amount of data needed to be sent to the worker - # Since we dont mask the data, only isel operations are done here - # Thus, to properly extract the subset, another masked crop needs to be done in the worker - unaligned_subset = unaligned.odc.crop(geom, apply_mask=False) - fut = executor.submit(_extract_cell_data, unaligned_subset, geom, aggregations) - else: - # Same as above but for multiple parts - unaligned_subsets = [unaligned.odc.crop(geom, apply_mask=False) for geom in geoms] - fut = executor.submit(_extract_split_cell_data, unaligned_subsets, geoms, aggregations) - futures[fut] = i + for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)): + futures[ + executor.submit( + _align_partition, + grid_partition, + raster.copy() if isinstance(raster, xr.Dataset) else raster, + aggregations, + pxbuffer, + ) + ] = i + + print("Submitted all partitions, waiting for results...") for future in track( - as_completed(futures), total=len(futures), description="Spatially aggregating ERA5 data..." + as_completed(futures), + total=len(futures), + description="Processing grid partitions...", ): - i = futures[future] - cell_data = future.result() - for var in unaligned.data_vars: - data[var][i, :] = cell_data[var] + try: + i = futures[future] + print(f"Processed partition {i + 1}/{len(futures)}") + part_data = future.result() + partial_data[i] = part_data + except Exception as e: + print(f"Error processing partition {i}: {e}") + raise e + + data = np.concatenate([partial_data[i] for i in range(len(partial_data))], axis=0) return data def aggregate_raster_into_grid( - raster: xr.Dataset, + raster: xr.Dataset | Callable[[], xr.Dataset], grid_gdf: gpd.GeoDataFrame, aggregations: _Aggregations, grid: Literal["hex", "healpix"], level: int, + n_partitions: int = 20, + max_workers: int = 5, + pxbuffer: int = 15, ): - aligned = _align_data(grid_gdf, raster, aggregations) + """Aggregate raster data into grid cells. + + Args: + raster (xr.Dataset | Callable[[], xr.Dataset]): Raster data or a function that returns it. + grid_gdf (gpd.GeoDataFrame): The grid to aggregate into. + aggregations (_Aggregations): The aggregations to perform. + grid (Literal["hex", "healpix"]): The type of grid to use. + level (int): The level of the grid. + n_partitions (int, optional): Number of partitions to divide the grid into. Defaults to 20. + max_workers (int, optional): Maximum number of worker processes. Defaults to 5. + pxbuffer (int, optional): Pixel buffer around each grid cell. Defaults to 15. + + Returns: + xr.Dataset: Aggregated data aligned with the grid. + + """ + aligned = _align_data( + grid_gdf, + raster, + aggregations, + n_partitions=n_partitions, + max_workers=max_workers, + pxbuffer=pxbuffer, + ) + # Dims of aligned: (cell_ids, variables, aggregations, other dims...) + + if callable(raster) and not isinstance(raster, xr.Dataset): + raster = raster() cell_ids = grids.get_cell_ids(grid, level) - dims = ["cell_ids"] - coords = {"cell_ids": cell_ids} - for dim in raster.dims: - if dim not in ["y", "x", "latitude", "longitude"]: - dims.append(dim) - coords[dim] = raster.coords[dim] + dims = ["cell_ids", "variables", "aggregations"] + coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()} + for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}: + dims.append(dim) + coords[dim] = raster.coords[dim] + + ongrid = xr.DataArray(aligned, dims=dims, coords=coords).to_dataset("variables") - data_vars = {var: (dims, values) for var, values in aligned.items()} - ongrid = xr.Dataset( - data_vars, - coords=coords, - ) gridinfo = { "grid_name": "h3" if grid == "hex" else grid, "level": level, @@ -247,8 +509,7 @@ def aggregate_raster_into_grid( gridinfo["indexing_scheme"] = "nested" ongrid.cell_ids.attrs = gridinfo for var in raster.data_vars: - for v in aggregations.varnames(var): - ongrid[v].attrs = raster[var].attrs + ongrid[var].attrs = raster[var].attrs ongrid = xdggs.decode(ongrid) return ongrid diff --git a/src/entropice/alphaearth.py b/src/entropice/alphaearth.py index 1b48b2e..d568cf9 100644 --- a/src/entropice/alphaearth.py +++ b/src/entropice/alphaearth.py @@ -77,7 +77,7 @@ def download(grid: Literal["hex", "healpix"], level: int): return feature.set(mean_dict) # Process grid in batches of 100 - batch_size = 100 + batch_size = 50 all_results = [] n_batches = len(grid_gdf) // batch_size for batch_num, batch_grid in track( diff --git a/src/entropice/arcticdem.py b/src/entropice/arcticdem.py index 824581b..3d22aa5 100644 --- a/src/entropice/arcticdem.py +++ b/src/entropice/arcticdem.py @@ -1,4 +1,5 @@ import datetime +import multiprocessing as mp from dataclasses import dataclass from math import ceil from typing import Literal @@ -9,6 +10,7 @@ import cupyx.scipy.signal import cyclopts import dask.array import dask.distributed as dd +import icechunk import icechunk.xarray import numpy as np import smart_geocubes @@ -301,18 +303,23 @@ def enrich(): print("Enrichment complete.") -@cli.command() -def aggregate(grid: Literal["hex", "healpix"], level: int): - with stopwatch(f"Loading {grid} grid at level {level}"): - grid_gdf = grids.open(grid, level) - # ? Mask out water, since we don't want to aggregate over oceans - grid_gdf = watermask.clip_grid(grid_gdf) - +def _open_adem(): arcticdem_store = get_arcticdem_stores() accessor = smart_geocubes.ArcticDEM32m(arcticdem_store) adem = accessor.open_xarray() assert {"x", "y"} == set(adem.dims) assert adem.odc.crs == "EPSG:3413" + return adem + + +@cli.command() +def aggregate(grid: Literal["hex", "healpix"], level: int, max_workers: int = 20): + mp.set_start_method("forkserver", force=True) + + with stopwatch(f"Loading {grid} grid at level {level}"): + grid_gdf = grids.open(grid, level) + # ? Mask out water, since we don't want to aggregate over oceans + grid_gdf = watermask.clip_grid(grid_gdf) aggregations = _Aggregations( mean=True, @@ -320,13 +327,29 @@ def aggregate(grid: Literal["hex", "healpix"], level: int): std=True, min=True, max=True, - quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99], + median=True, + quantiles=(0.01, 0.05, 0.25, 0.75, 0.95, 0.99), + ) + + aggregated = aggregate_raster_into_grid( + _open_adem, + grid_gdf, + aggregations, + grid, + level, + max_workers=max_workers, + n_partitions=200, # Results in 10GB large patches + pxbuffer=30, ) - aggregated = aggregate_raster_into_grid(adem, grid_gdf, aggregations, grid, level) store = get_arcticdem_stores(grid, level) with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"): aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated)) + print("Aggregation complete.") + + print("### Finished ArcticDEM processing ###") + stopwatch.summary() + if __name__ == "__main__": cli()