leaking grid aggregations

This commit is contained in: parent 18cc1b8601, commit 341fa7b836
5 changed files with 521 additions and 154 deletions

pixi.lock (generated, 85 lines changed)
@@ -543,6 +543,7 @@ environments:
 - pypi: https://files.pythonhosted.org/packages/ab/b5/36c712098e6191d1b4e349304ef73a8d06aed77e56ceaac8c0a306c7bda1/jupyterlab_widgets-3.0.16-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/c4/bd/ba44a47578ea48ee28b54543c1de8c529eedad8317516a2a753e6d9c77c5/lonboard-0.13.0-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl
+- pypi: https://files.pythonhosted.org/packages/52/40/617b15e62d5de1718e81ee436a1f19d4d40274ead97ac0eda188baebb986/memray-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
 - pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/93/cf/be4e93afbfa0def2cd6fac9302071db0bd6d0617999ecbf53f92b9398de3/multiurl-0.3.7-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl

@@ -594,6 +595,7 @@ environments:
 - pypi: https://files.pythonhosted.org/packages/39/60/868371b6482ccd9ef423c6f62650066cf8271fdb2ee84f192695ad6b7a96/streamlit-1.51.0-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl
+- pypi: https://files.pythonhosted.org/packages/53/b3/95ab646b0c908823d71e49ab8b5949ec9f33346cee3897d1af6be28a8d91/textual-6.6.0-py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/8d/c0/fdf9d3ee103ce66a55f0532835ad5e154226c5222423c6636ba049dc42fc/traittypes-0.2.3-py2.py3-none-any.whl
 - pypi: https://files.pythonhosted.org/packages/06/af/413f6b172f9d4c4943b980a9fd96bb4d915680ce8f79c07de6f697b45c8b/ultraplot-1.65.1-py3-none-any.whl

@@ -2808,7 +2810,7 @@ packages:
 - pypi: ./
   name: entropice
   version: 0.1.0
-  sha256: d95c691c76206bf54e207fe02b100a247a3847f37135d1cdf6ee18165770ea46
+  sha256: 788df3ea7773f54fce8274d57d51af8edbab106c5b2082f8efca6639aa3eece9
   requires_dist:
   - aiohttp>=3.12.11
   - bokeh>=3.7.3

@@ -2862,6 +2864,7 @@ packages:
   - s3fs>=2025.10.0,<2026
   - xarray-spatial @ git+https://github.com/relativityhd/xarray-spatial
   - cupy-xarray>=0.1.4,<0.2
+  - memray>=1.19.1,<2
   requires_python: '>=3.13,<3.14'
   editable: true
 - pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7

@@ -6344,6 +6347,58 @@ packages:
   - pkg:pypi/mdurl?source=hash-mapping
   size: 14465
   timestamp: 1733255681319
+- pypi: https://files.pythonhosted.org/packages/52/40/617b15e62d5de1718e81ee436a1f19d4d40274ead97ac0eda188baebb986/memray-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
+  name: memray
+  version: 1.19.1
+  sha256: 9770b6046a8b7d7bb982c2cefb324a60a0ee7cd7c35c58c0df058355a9a54676
+  requires_dist:
+  - jinja2>=2.9
+  - typing-extensions ; python_full_version < '3.8'
+  - rich>=11.2.0
+  - textual>=0.41.0
+  - cython ; extra == 'test'
+  - greenlet ; python_full_version < '3.14' and extra == 'test'
+  - pytest ; extra == 'test'
+  - pytest-cov ; extra == 'test'
+  - ipython ; extra == 'test'
+  - setuptools ; extra == 'test'
+  - pytest-textual-snapshot ; extra == 'test'
+  - textual>=0.43,!=0.65.2,!=0.66 ; extra == 'test'
+  - packaging ; extra == 'test'
+  - ipython ; extra == 'docs'
+  - bump2version ; extra == 'docs'
+  - sphinx ; extra == 'docs'
+  - furo ; extra == 'docs'
+  - sphinx-argparse ; extra == 'docs'
+  - towncrier ; extra == 'docs'
+  - black ; extra == 'lint'
+  - flake8 ; extra == 'lint'
+  - isort ; extra == 'lint'
+  - mypy ; extra == 'lint'
+  - check-manifest ; extra == 'lint'
+  - asv ; extra == 'benchmark'
+  - cython ; extra == 'dev'
+  - greenlet ; python_full_version < '3.14' and extra == 'dev'
+  - pytest ; extra == 'dev'
+  - pytest-cov ; extra == 'dev'
+  - ipython ; extra == 'dev'
+  - setuptools ; extra == 'dev'
+  - pytest-textual-snapshot ; extra == 'dev'
+  - textual>=0.43,!=0.65.2,!=0.66 ; extra == 'dev'
+  - packaging ; extra == 'dev'
+  - black ; extra == 'dev'
+  - flake8 ; extra == 'dev'
+  - isort ; extra == 'dev'
+  - mypy ; extra == 'dev'
+  - check-manifest ; extra == 'dev'
+  - ipython ; extra == 'dev'
+  - bump2version ; extra == 'dev'
+  - sphinx ; extra == 'dev'
+  - furo ; extra == 'dev'
+  - sphinx-argparse ; extra == 'dev'
+  - towncrier ; extra == 'dev'
+  - asv ; extra == 'dev'
+  requires_python: '>=3.7.0'
 - pypi: https://files.pythonhosted.org/packages/b2/d6/de0cc74f8d36976aeca0dd2e9cbf711882ff8e177495115fd82459afdc4d/mercantile-1.2.1-py3-none-any.whl
   name: mercantile
   version: 1.2.1

@@ -8866,6 +8921,34 @@ packages:
   - pkg:pypi/terminado?source=hash-mapping
   size: 22452
   timestamp: 1710262728753
+- pypi: https://files.pythonhosted.org/packages/53/b3/95ab646b0c908823d71e49ab8b5949ec9f33346cee3897d1af6be28a8d91/textual-6.6.0-py3-none-any.whl
+  name: textual
+  version: 6.6.0
+  sha256: 5a9484bd15ee8a6fd8ac4ed4849fb25ee56bed2cecc7b8a83c4cd7d5f19515e5
+  requires_dist:
+  - markdown-it-py[linkify]>=2.1.0
+  - mdit-py-plugins
+  - platformdirs>=3.6.0,<5
+  - pygments>=2.19.2,<3.0.0
+  - rich>=14.2.0
+  - tree-sitter>=0.25.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-bash>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-css>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-go>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-html>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-java>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-javascript>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-json>=0.24.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-markdown>=0.3.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-python>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-regex>=0.24.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-rust>=0.23.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-sql>=0.3.11 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-toml>=0.6.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-xml>=0.7.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - tree-sitter-yaml>=0.6.0 ; python_full_version >= '3.10' and extra == 'syntax'
+  - typing-extensions>=4.4.0,<5.0.0
+  requires_python: '>=3.9,<4.0'
 - conda: https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.6.0-pyhecae5ae_0.conda
   sha256: 6016672e0e72c4cf23c0cf7b1986283bd86a9c17e8d319212d78d8e9ae42fdfd
   md5: 9d64911b31d57ca443e9f1e36b04385f

@@ -57,7 +57,7 @@ dependencies = [
     "xgboost>=3.1.1,<4",
     "s3fs>=2025.10.0,<2026",
     "xarray-spatial",
-    "cupy-xarray>=0.1.4,<0.2",
+    "cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2",
 ]

 [project.scripts]
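
The new memray dependency added above is the memory profiler used to chase the leak this commit is about. A minimal sketch of its Python API; the output path and workload are hypothetical placeholders, not part of this project:

import memray

def run_grid_aggregation():
    # Placeholder for the leaking workload being profiled.
    return [bytes(1_000_000) for _ in range(10)]

# Trace allocations into a capture file; inspect it afterwards with
# `memray flamegraph grid_agg.bin`.
with memray.Tracker("grid_agg.bin"):
    run_grid_aggregation()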
@@ -1,16 +1,27 @@
 """Aggregation helpers."""

+import gc
+import os
+import sys
+from collections import defaultdict
+from collections.abc import Callable, Generator
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from dataclasses import dataclass, field
+from functools import cache
 from typing import Literal

 import cupy_xarray
 import geopandas as gpd
 import numpy as np
 import odc.geo.geobox
+import pandas as pd
+import psutil
 import shapely
 import shapely.ops
+import sklearn.cluster
 import xarray as xr
 import xdggs
 import xvec
+from rich import print
 from rich.progress import track
 from shapely.geometry import LineString, Polygon
 from stopuhr import stopwatch

@@ -19,34 +30,100 @@ from xdggs.healpix import HealpixInfo
 from entropice import grids


-@dataclass
+@dataclass(frozen=True)
 class _Aggregations:
+    # ! The ordering is super important for this class!
     mean: bool = True
     sum: bool = False
     std: bool = False
     min: bool = False
     max: bool = False
-    quantiles: list[float] = field(default_factory=lambda: [])
+    median: bool = False  # Will be mapped as quantile 0.5
+    quantiles: tuple[float, ...] = field(default_factory=tuple)

-    def varnames(self, vars: list[str] | str) -> list[str]:
-        if isinstance(vars, str):
-            vars = [vars]
-        agg_vars = []
-        for var in vars:
-            if self.mean:
-                agg_vars.append(f"{var}_mean")
-            if self.sum:
-                agg_vars.append(f"{var}_sum")
-            if self.std:
-                agg_vars.append(f"{var}_std")
-            if self.min:
-                agg_vars.append(f"{var}_min")
-            if self.max:
-                agg_vars.append(f"{var}_max")
-            for q in self.quantiles:
-                q_int = int(q * 100)
-                agg_vars.append(f"{var}_p{q_int}")
-        return agg_vars
+    def __post_init__(self):
+        assert isinstance(self.quantiles, tuple), "Quantiles must be a tuple, otherwise the class is not hashable"
+        if self.median and 0.5 in self.quantiles:
+            raise ValueError("Median aggregation cannot be used together with quantile 0.5")
+
+    @cache
+    def __len__(self) -> int:
+        length = 0
+        if self.mean:
+            length += 1
+        if self.sum:
+            length += 1
+        if self.std:
+            length += 1
+        if self.min:
+            length += 1
+        if self.max:
+            length += 1
+        if self.median and 0.5 not in self.quantiles:
+            length += 1
+        length += len(self.quantiles)
+        return length
+
+    def aggnames(self) -> list[str]:
+        names = []
+        if self.mean:
+            names.append("mean")
+        if self.sum:
+            names.append("sum")
+        if self.std:
+            names.append("std")
+        if self.min:
+            names.append("min")
+        if self.max:
+            names.append("max")
+        if self.median:
+            names.append("median")
+        for q in sorted(self.quantiles):  # ? Ordering is important here!
+            q_int = int(q * 100)
+            names.append(f"p{q_int}")
+        return names
+
+    def _agg_cell_data_single(self, flattened_var: xr.DataArray) -> np.ndarray:
+        others_shape = tuple([flattened_var.sizes[dim] for dim in flattened_var.dims if dim != "z"])
+        cell_data = np.full((len(self), *others_shape), np.nan, dtype=np.float32)
+        j = 0
+        if self.mean:
+            cell_data[j, ...] = flattened_var.mean(dim="z", skipna=True).to_numpy()
+            j += 1
+        if self.sum:
+            cell_data[j, ...] = flattened_var.sum(dim="z", skipna=True).to_numpy()
+            j += 1
+        if self.std:
+            cell_data[j, ...] = flattened_var.std(dim="z", skipna=True).to_numpy()
+            j += 1
+        if self.min:
+            cell_data[j, ...] = flattened_var.min(dim="z", skipna=True).to_numpy()
+            j += 1
+        if self.max:
+            cell_data[j, ...] = flattened_var.max(dim="z", skipna=True).to_numpy()
+            j += 1
+        if len(self.quantiles) > 0 or self.median:
+            quantiles_to_compute = sorted(self.quantiles)  # ? Ordering is important here!
+            if self.median:
+                quantiles_to_compute.insert(0, 0.5)
+            cell_data[j:, ...] = flattened_var.quantile(
+                q=quantiles_to_compute,
+                dim="z",
+                skipna=True,
+            ).to_numpy()
+        return cell_data
+
+    def agg_cell_data(self, flattened: xr.Dataset | xr.DataArray) -> np.ndarray:
+        if isinstance(flattened, xr.DataArray):
+            return self._agg_cell_data_single(flattened)
+
+        others_shape = tuple([flattened.sizes[dim] for dim in flattened.dims if dim != "z"])
+        cell_data = np.full((len(flattened.data_vars), len(self), *others_shape), np.nan, dtype=np.float32)
+        for i, var in enumerate(flattened.data_vars):
+            cell_data[i, ...] = self._agg_cell_data_single(flattened[var])
+        # Transform to numpy arrays
+        # if flattened.cupy.is_cupy:
+        #     cell_data = cell_data.get()
+        return cell_data


 def _crosses_antimeridian(geom: Polygon) -> bool:
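
The frozen _Aggregations class above encodes an ordering contract: aggnames() and the rows produced by agg_cell_data() must line up index by index, which is why the class is hashable and __len__ can be cached. A small sketch of what that implies (values illustrative):

aggs = _Aggregations(mean=True, std=True, median=True, quantiles=(0.05, 0.95))
assert len(aggs) == len(aggs.aggnames())  # ["mean", "std", "median", "p5", "p95"]
# Row j of the array returned by aggs.agg_cell_data(...) corresponds to
# aggs.aggnames()[j]; reordering the fields or name lists would silently
# mislabel every aggregated value.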
@@ -77,11 +154,11 @@ def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
     return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1


-@stopwatch("Getting corrected geometries", log=False)
+@stopwatch("Correcting geometries", log=False)
 def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]:
     geom, gbox, crs = inp
     # cell.geometry is a shapely Polygon
-    if not _crosses_antimeridian(geom):
+    if crs != "EPSG:4326" or not _crosses_antimeridian(geom):
         geoms = [geom]
     # Split geometry in case it crossed antimeridian
     else:

@@ -92,153 +169,338 @@ def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]:
     return geoms
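
For context, _crosses_antimeridian (truncated above) typically flags polygons whose longitude span is implausibly wide. A hedged sketch of one common heuristic plus a shapely-based split; the 180-degree cut line and the example cell are assumptions, not the project's actual implementation:

import shapely.ops
from shapely.geometry import LineString, Polygon

def _crosses_antimeridian_sketch(geom: Polygon, threshold: float = 180.0) -> bool:
    # Heuristic: a lon/lat bounding box wider than `threshold` degrees usually
    # means the polygon wraps around the antimeridian, not that it is that wide.
    minx, _, maxx, _ = geom.bounds
    return (maxx - minx) > threshold

# A cell straddling the antimeridian, expressed with longitudes in [0, 360),
# can be cut at lon=180 and the eastern part shifted back afterwards:
cell = Polygon([(175, 60), (185, 60), (185, 65), (175, 65)])
parts = shapely.ops.split(cell, LineString([(180, -90), (180, 90)]))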


-@stopwatch("Correcting geometries")
-def get_corrected_geometries(grid_gdf: gpd.GeoDataFrame, gbox: odc.geo.geobox.GeoBox):
-    """Get corrected geometries for antimeridian-crossing polygons.
-
-    Args:
-        grid_gdf (gpd.GeoDataFrame): Grid GeoDataFrame.
-        gbox (odc.geo.geobox.GeoBox): GeoBox for spatial reference.
-
-    Returns:
-        list[list[odc.geo.Geometry]]: List of corrected, georeferenced geometries.
-
-    """
-    with ProcessPoolExecutor(max_workers=20) as executor:
-        cell_geometries = list(
-            executor.map(_get_corrected_geoms, [(row.geometry, gbox, grid_gdf.crs) for _, row in grid_gdf.iterrows()])
-        )
-    return cell_geometries
-
-
-@stopwatch("Aggregating cell data", log=False)
-def _agg_cell_data(flattened: xr.Dataset, aggregations: _Aggregations):
-    cell_data = {}
-    for var in flattened.data_vars:
-        if aggregations.mean:
-            cell_data[f"{var}_mean"] = flattened[var].mean(dim="z", skipna=True)
-        if aggregations.sum:
-            cell_data[f"{var}_sum"] = flattened[var].sum(dim="z", skipna=True)
-        if aggregations.std:
-            cell_data[f"{var}_std"] = flattened[var].std(dim="z", skipna=True)
-        if aggregations.min:
-            cell_data[f"{var}_min"] = flattened[var].min(dim="z", skipna=True)
-        if aggregations.max:
-            cell_data[f"{var}_max"] = flattened[var].max(dim="z", skipna=True)
-        if len(aggregations.quantiles) > 0:
-            quantile_values = flattened[var].quantile(
-                q=aggregations.quantiles,
-                dim="z",
-                skipna=True,
-            )
-            for q, qv in zip(aggregations.quantiles, quantile_values):
-                q_int = int(q * 100)
-                cell_data[f"{var}_p{q_int}"] = qv
-    # Transform to numpy arrays
-    for key in cell_data:
-        cell_data[key] = cell_data[key].cupy.as_numpy().values
-    return cell_data
-
-
 @stopwatch("Extracting cell data", log=False)
-def _extract_cell_data(ds: xr.Dataset, geom: odc.geo.Geometry, aggregations: _Aggregations):
-    spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"]
-    cropped: xr.Dataset = ds.odc.crop(geom).drop_vars("spatial_ref")
-    flattened = cropped.stack(z=spatdims)
-    if flattened.z.size > 3000:
-        flattened = flattened.cupy.as_cupy()
-    cell_data = _agg_cell_data(flattened, aggregations)
+def _extract_cell_data(cropped: xr.Dataset | xr.DataArray, aggregations: _Aggregations):
+    spatdims = ["latitude", "longitude"] if "latitude" in cropped.dims and "longitude" in cropped.dims else ["y", "x"]
+    flattened = cropped.stack(z=spatdims)  # noqa: PD013
+    # if flattened.z.size > 3000:
+    #     flattened = flattened.cupy.as_cupy()
+    cell_data = aggregations.agg_cell_data(flattened)
     return cell_data

     # with np.errstate(divide="ignore", invalid="ignore"):
     #     cell_data = cropped.mean(spatdims)
     # return {var: cell_data[var].values for var in cell_data.data_vars}


 @stopwatch("Extracting split cell data", log=False)
-def _extract_split_cell_data(ds: xr.Dataset, geoms: list[odc.geo.Geometry], aggregations: _Aggregations):
-    spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"]
-    cropped: list[xr.Dataset] = [ds.odc.crop(geom).drop_vars("spatial_ref") for ds, geom in zip(ds, geoms)]
-    flattened = xr.concat([c.stack(z=spatdims) for c in cropped], dim="z")
-    cell_data = _agg_cell_data(flattened, aggregations)
+def _extract_split_cell_data(cropped_list: list[xr.Dataset | xr.DataArray], aggregations: _Aggregations):
+    spatdims = (
+        ["latitude", "longitude"]
+        if "latitude" in cropped_list[0].dims and "longitude" in cropped_list[0].dims
+        else ["y", "x"]
+    )
+    flattened = xr.concat([c.stack(z=spatdims) for c in cropped_list], dim="z")  # noqa: PD013
+    cell_data = aggregations.agg_cell_data(flattened)
     return cell_data

     # partial_counts = [part.notnull().sum(dim=spatdims) for part in parts]
     # with np.errstate(divide="ignore", invalid="ignore"):
     #     partial_means = [part.sum(dim=spatdims) for part in parts]
     # n = xr.concat(partial_counts, dim="part").sum("part")
     # cell_data = xr.concat(partial_means, dim="part").sum("part") / n
     # return {var: cell_data[var].values for var in cell_data.data_vars}
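
Both extractors rely on collapsing the two spatial dimensions into a single "z" dimension, so every statistic reduces over one axis and antimeridian-split parts can simply be concatenated along it. A self-contained sketch of that trick:

import numpy as np
import xarray as xr

ds = xr.Dataset({"elev": (("y", "x"), np.random.rand(4, 4))})
flat = ds.stack(z=["y", "x"])  # (y, x) -> one "z" axis of 16 pixels
halves = [ds.isel(x=slice(0, 2)), ds.isel(x=slice(2, 4))]
merged = xr.concat([h.stack(z=["y", "x"]) for h in halves], dim="z")
assert merged.sizes["z"] == flat.sizes["z"]  # split parts aggregate like one cell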


+@stopwatch("Cropping (and reading?) cell data", log=False)
+def _read_cell_data(unaligned: xr.Dataset | xr.DataArray, geom: odc.geo.Geometry) -> xr.Dataset:
+    cropped: xr.Dataset = unaligned.odc.crop(geom, apply_mask=True).drop_vars("spatial_ref")
+    return cropped.compute()
+
+
+@stopwatch("Cropping (and reading?) split cell data", log=False)
+def _read_split_cell_data(unaligned: xr.Dataset | xr.DataArray, geoms: list[odc.geo.Geometry]) -> list[xr.Dataset]:
+    cropped: list[xr.Dataset] = [unaligned.odc.crop(geom, apply_mask=True).drop_vars("spatial_ref") for geom in geoms]
+    return [c.compute() for c in cropped]
+
+
+@stopwatch("Processing cell", log=False)
+def _process_geom(poly: Polygon, unaligned: xr.Dataset | xr.DataArray, aggregations: _Aggregations):
+    geoms = _get_corrected_geoms((poly, unaligned.odc.geobox, unaligned.odc.crs))
+    if len(geoms) == 0:
+        raise ValueError("No geometries to process")
+    elif len(geoms) == 1:
+        cropped = _read_cell_data(unaligned, geoms[0])
+        cell_data = _extract_cell_data(cropped, aggregations)
+    else:
+        cropped = _read_split_cell_data(unaligned, geoms)
+        cell_data = _extract_split_cell_data(cropped, aggregations)
+    return cell_data
+
+
+def _partition_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]:
+    # ! TODO: Adjust for antimeridian crossing, this is only needed if CRS is not 3413
+    # Simple partitioning by splitting the GeoDataFrame into n_partitions parts
+    centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y})
+    labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids)
+    for i in range(n_partitions):
+        partition = grid_gdf[labels == i]
+        yield partition
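
_partition_grid clusters cell centroids with KMeans so each partition is spatially compact, which keeps the per-partition bounding-box crop small. A quick usage sketch with a toy grid (sizes illustrative):

import geopandas as gpd
from shapely.geometry import Point

toy_grid = gpd.GeoDataFrame(geometry=[Point(x, y).buffer(0.5) for x in range(10) for y in range(10)])
parts = list(_partition_grid(toy_grid, n_partitions=4))
# Every cell lands in exactly one partition; partition sizes may be uneven.
assert sum(len(p) for p in parts) == len(toy_grid)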


+class _MemoryProfiler:
+    def __init__(self):
+        self.process = psutil.Process()
+        self._logs = defaultdict(list)
+
+    def log_memory(self, message: str, log: bool = True):
+        mem = self.process.memory_info().rss / 1024**3
+        self._logs[message].append(mem)
+        if log:
+            print(f"[MemoryProfiler] {message}: {mem:.2f} GB")
+
+    def summary(self):
+        summary_lines = ["[MemoryProfiler] Memory usage summary:"]
+        for message, mems in self._logs.items():
+            mems_str = [f"{m:.2f}" for m in mems]
+            summary_lines.append(f"  {message}: min={min(mems):.2f} GB, max={max(mems):.2f} GB")
+            summary_lines.append(f"  Measurements: {', '.join(mems_str)}")
+        return "\n".join(summary_lines)


+memprof = None
+
+
+def _init_worker():
+    global memprof
+    memprof = _MemoryProfiler()
+
+
+def _align_partition(
+    grid_partition_gdf: gpd.GeoDataFrame,
+    raster: xr.Dataset | Callable[[], xr.Dataset],
+    aggregations: _Aggregations,
+    pxbuffer: int,
+):
+    # Strategy for each cell:
+    # 1. Correct the geometry to account for the antimeridian
+    # 2. Crop the dataset and load the data into memory
+    # 3. Aggregate the data
+    # 4. Store the aggregated data in the correct location in the output arrays
+
+    # Steps 1-3 are done in different processes, since all of these steps are CPU bound.
+    # Even the reading part is CPU bound, since it involves applying masks and cropping the data.
+    # Also, the backend (icechunk or zarr) can do some computations, like combining chunks together.
+
+    # For the ArcticDEM on the Hex3 grid the avg. time per task is:
+    # 1. ~0.01s
+    # 2. ~1-3s
+    # 3. ~10-20s
+    # And for Hex6:
+    # 1. unmeasurable
+    # 2. ~0.1-0.3s
+    # 3. ~0.05s
+
+    memprof.log_memory("Before reading partial raster", log=False)
+
+    if callable(raster) and not isinstance(raster, xr.Dataset):
+        raster = raster()
+
+    if hasattr(raster, "__dask_graph__"):
+        graph_size = len(raster.__dask_graph__())
+        graph_memory = sys.getsizeof(raster.__dask_graph__()) / 1024**2
+        print(f"Raster graph size: {graph_size} tasks, {graph_memory:.2f} MB")
+
+    others_shape = tuple([raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]])
+    data_shape = (len(grid_partition_gdf), len(raster.data_vars), len(aggregations), *others_shape)
+    data = np.full(data_shape, np.nan, dtype=np.float32)
+
+    partial_extent = odc.geo.BoundingBox(*grid_partition_gdf.total_bounds, crs=grid_partition_gdf.crs)
+    partial_extent = partial_extent.buffered(
+        raster.odc.geobox.resolution.x * pxbuffer,
+        raster.odc.geobox.resolution.y * pxbuffer,
+    )  # buffer by pxbuffer pixels
+    with stopwatch("Cropping raster to partition extent", log=False):
+        try:
+            partial_raster = raster.odc.crop(partial_extent, apply_mask=False).compute()
+        except Exception as e:
+            print(f"Error cropping raster to partition extent: {e}")
+            return data
+
+    print(f"{os.getpid()}: Partial raster size: {partial_raster.nbytes / 1e9:.2f} GB ({len(grid_partition_gdf)} cells)")
+    memprof.log_memory("After reading partial raster", log=True)
+
+    with ProcessPoolExecutor(
+        max_workers=60,
+    ) as executor:
+        futures = {}
+        for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
+            # ? Splitting the geometries already here and passing only the cropped data to the worker to
+            # reduce pickling overhead
+            geoms = _get_corrected_geoms((row.geometry, partial_raster.odc.geobox, partial_raster.odc.crs))
+            if len(geoms) == 0:
+                continue
+            elif len(geoms) == 1:
+                cropped = [_read_cell_data(partial_raster, geoms[0])]
+            else:
+                cropped = _read_split_cell_data(partial_raster, geoms)
+            futures[executor.submit(_extract_split_cell_data, cropped, aggregations)] = i
+
+            # Clean up immediately after submitting
+            del geoms, cropped
+
+        for future in as_completed(futures):
+            i = futures[future]
+            try:
+                cell_data = future.result()
+            except (SystemError, SystemExit, KeyboardInterrupt) as e:
+                raise e
+            except Exception as e:
+                print(f"Error processing cell {i}: {e}")
+                continue
+            data[i, ...] = cell_data
+            del cell_data, futures[future]
+
+    # for i, (idx, row) in enumerate(grid_partition_gdf.iterrows()):
+    #     try:
+    #         cell_data = _process_geom(row.geometry, partial_raster, aggregations)
+    #     except (SystemError, SystemExit, KeyboardInterrupt) as e:
+    #         raise e
+    #     except Exception as e:
+    #         print(f"Error processing cell {row['cell_id']}: {e}")
+    #         continue
+    #     data[i, ...] = cell_data
+
+    partial_raster.close()
+    raster.close()
+    del partial_raster, raster, future, futures
+    gc.collect()
+    memprof.log_memory("After cleaning", log=True)
+
+    print("Finished processing partition")
+    print("### Stopwatch summary ###\n")
+    print(stopwatch.summary())
+    print("### Memory summary ###\n")
+    print(memprof.summary())
+    print("#########################")
+    # with ProcessPoolExecutor(
+    #     max_workers=max_workers,
+    #     initializer=_init_partial_raster_global,
+    #     initargs=(raster, grid_partition_gdf, pxbuffer),
+    # ) as executor:
+    #     futures = {
+    #         executor.submit(_process_geom_global, row.geometry, aggregations): i
+    #         for i, (idx, row) in enumerate(grid_partition_gdf.iterrows())
+    #     }
+
+    #     others_shape = tuple(
+    #         [raster.sizes[dim] for dim in raster.dims if dim not in ["y", "x", "latitude", "longitude"]]
+    #     )
+    #     data_shape = (len(futures), len(raster.data_vars), len(aggregations), *others_shape)
+    #     print(f"Allocating partial data array of shape {data_shape}")
+    #     data = np.full(data_shape, np.nan, dtype=np.float32)

+    #     lap = time.time()
+    #     for future in track(
+    #         as_completed(futures),
+    #         total=len(futures),
+    #         description="Spatially aggregating ERA5 data...",
+    #     ):
+    #         print(f"time since last cell: {time.time() - lap:.2f}s")
+    #         i = futures[future]
+    #         try:
+    #             cell_data = future.result()
+    #             data[i, ...] = cell_data
+    #         except Exception as e:
+    #             print(f"Error processing cell {i}: {e}")
+    #             continue
+    #         finally:
+    #             # Try to free memory
+    #             del futures[future], future
+    #             if "cell_data" in locals():
+    #                 del cell_data
+    #             gc.collect()
+    #             lap = time.time()
+    return data
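
The commit title points at leaking aggregations: memory a long-lived worker never returns to the OS. Spawning a fresh process per partition is one way to guarantee release; a minimal sketch of the same idea using ProcessPoolExecutor's max_tasks_per_child (Python 3.11+), shown as an alternative pattern rather than what the code above does:

from concurrent.futures import ProcessPoolExecutor

def handle_partition(i: int) -> int:
    # Placeholder for _align_partition; allocates and returns a small result.
    buf = bytearray(50_000_000)
    return len(buf)

if __name__ == "__main__":
    # max_tasks_per_child=1 recycles the worker after every partition, so any
    # memory the task leaked dies with the process instead of accumulating.
    with ProcessPoolExecutor(max_workers=4, max_tasks_per_child=1) as ex:
        print(list(ex.map(handle_partition, range(8))))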


 @stopwatch("Aligning data with grid")
 def _align_data(
     grid_gdf: gpd.GeoDataFrame,
-    unaligned: xr.Dataset,
+    raster: xr.Dataset | Callable[[], xr.Dataset],
     aggregations: _Aggregations,
-) -> dict[str, np.ndarray]:
-    # Persist the dataset, as all the operations here MUST NOT be lazy
-    unaligned = unaligned.load()
-
-    cell_geometries = get_corrected_geometries(grid_gdf, unaligned.odc.geobox)
-    other_dims_shape = tuple(
-        [unaligned.sizes[dim] for dim in unaligned.dims if dim not in ["y", "x", "latitude", "longitude"]]
-    )
-    data_shape = (len(cell_geometries), *other_dims_shape)
-    data = {var: np.full(data_shape, np.nan, dtype=np.float32) for var in aggregations.varnames(unaligned.data_vars)}
-    with ProcessPoolExecutor(max_workers=10) as executor:
+    n_partitions: int,
+    max_workers: int,
+    pxbuffer: int,
+):
+    partial_data = {}
+
+    with ProcessPoolExecutor(
+        max_workers=max_workers,
+        initializer=_init_worker,
+        # initializer=_init_raster_global,
+        # initargs=(raster,),
+    ) as executor:
         futures = {}
-        for i, geoms in track(
-            enumerate(cell_geometries), total=len(cell_geometries), description="Submitting cell extraction tasks..."
-        ):
-            if len(geoms) == 0:
-                continue
-            elif len(geoms) == 1:
-                geom = geoms[0]
-                # Reduce the amount of data needed to be sent to the worker
-                # Since we don't mask the data, only isel operations are done here
-                # Thus, to properly extract the subset, another masked crop needs to be done in the worker
-                unaligned_subset = unaligned.odc.crop(geom, apply_mask=False)
-                fut = executor.submit(_extract_cell_data, unaligned_subset, geom, aggregations)
-            else:
-                # Same as above but for multiple parts
-                unaligned_subsets = [unaligned.odc.crop(geom, apply_mask=False) for geom in geoms]
-                fut = executor.submit(_extract_split_cell_data, unaligned_subsets, geoms, aggregations)
-            futures[fut] = i
+        for i, grid_partition in enumerate(_partition_grid(grid_gdf, n_partitions)):
+            futures[
+                executor.submit(
+                    _align_partition,
+                    grid_partition,
+                    raster.copy() if isinstance(raster, xr.Dataset) else raster,
+                    aggregations,
+                    pxbuffer,
+                )
+            ] = i
+
+        print("Submitted all partitions, waiting for results...")

         for future in track(
-            as_completed(futures), total=len(futures), description="Spatially aggregating ERA5 data..."
+            as_completed(futures),
+            total=len(futures),
+            description="Processing grid partitions...",
         ):
             try:
                 i = futures[future]
-                cell_data = future.result()
-                for var in unaligned.data_vars:
-                    data[var][i, :] = cell_data[var]
+                print(f"Processed partition {i + 1}/{len(futures)}")
+                part_data = future.result()
+                partial_data[i] = part_data
             except Exception as e:
                 print(f"Error processing partition {i}: {e}")
                 raise e

+    data = np.concatenate([partial_data[i] for i in range(len(partial_data))], axis=0)
     return data


 def aggregate_raster_into_grid(
-    raster: xr.Dataset,
+    raster: xr.Dataset | Callable[[], xr.Dataset],
     grid_gdf: gpd.GeoDataFrame,
     aggregations: _Aggregations,
     grid: Literal["hex", "healpix"],
     level: int,
+    n_partitions: int = 20,
+    max_workers: int = 5,
+    pxbuffer: int = 15,
 ):
-    aligned = _align_data(grid_gdf, raster, aggregations)
+    """Aggregate raster data into grid cells.
+
+    Args:
+        raster (xr.Dataset | Callable[[], xr.Dataset]): Raster data or a function that returns it.
+        grid_gdf (gpd.GeoDataFrame): The grid to aggregate into.
+        aggregations (_Aggregations): The aggregations to perform.
+        grid (Literal["hex", "healpix"]): The type of grid to use.
+        level (int): The level of the grid.
+        n_partitions (int, optional): Number of partitions to divide the grid into. Defaults to 20.
+        max_workers (int, optional): Maximum number of worker processes. Defaults to 5.
+        pxbuffer (int, optional): Pixel buffer around each grid cell. Defaults to 15.
+
+    Returns:
+        xr.Dataset: Aggregated data aligned with the grid.
+
+    """
+    aligned = _align_data(
+        grid_gdf,
+        raster,
+        aggregations,
+        n_partitions=n_partitions,
+        max_workers=max_workers,
+        pxbuffer=pxbuffer,
+    )
+    # Dims of aligned: (cell_ids, variables, aggregations, other dims...)
+
+    if callable(raster) and not isinstance(raster, xr.Dataset):
+        raster = raster()

     cell_ids = grids.get_cell_ids(grid, level)

-    dims = ["cell_ids"]
-    coords = {"cell_ids": cell_ids}
-    for dim in raster.dims:
-        if dim not in ["y", "x", "latitude", "longitude"]:
-            dims.append(dim)
-            coords[dim] = raster.coords[dim]
+    dims = ["cell_ids", "variables", "aggregations"]
+    coords = {"cell_ids": cell_ids, "variables": list(raster.data_vars), "aggregations": aggregations.aggnames()}
+    for dim in set(raster.dims) - {"y", "x", "latitude", "longitude"}:
+        dims.append(dim)
+        coords[dim] = raster.coords[dim]

-    data_vars = {var: (dims, values) for var, values in aligned.items()}
-    ongrid = xr.Dataset(
-        data_vars,
-        coords=coords,
-    )
+    ongrid = xr.DataArray(aligned, dims=dims, coords=coords).to_dataset("variables")

     gridinfo = {
         "grid_name": "h3" if grid == "hex" else grid,
         "level": level,
@@ -247,8 +509,7 @@ def aggregate_raster_into_grid(
     gridinfo["indexing_scheme"] = "nested"
     ongrid.cell_ids.attrs = gridinfo
     for var in raster.data_vars:
-        for v in aggregations.varnames(var):
-            ongrid[v].attrs = raster[var].attrs
+        ongrid[var].attrs = raster[var].attrs

     ongrid = xdggs.decode(ongrid)
     return ongrid
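
The switch from a per-variable dict to one stacked array leans on xarray's to_dataset along a dimension. A small self-contained sketch of that reshaping:

import numpy as np
import xarray as xr

stacked = xr.DataArray(
    np.zeros((3, 2, 4)),  # (cell_ids, variables, aggregations)
    dims=["cell_ids", "variables", "aggregations"],
    coords={"variables": ["elev", "slope"], "aggregations": ["mean", "std", "p5", "p95"]},
)
# Splitting along "variables" yields one data_var per raster variable,
# each with dims (cell_ids, aggregations).
ds = stacked.to_dataset("variables")
assert set(ds.data_vars) == {"elev", "slope"}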
@@ -77,7 +77,7 @@ def download(grid: Literal["hex", "healpix"], level: int):
         return feature.set(mean_dict)

     # Process the grid in batches
-    batch_size = 100
+    batch_size = 50
     all_results = []
     n_batches = len(grid_gdf) // batch_size
     for batch_num, batch_grid in track(

@@ -1,4 +1,5 @@
 import datetime
+import multiprocessing as mp
 from dataclasses import dataclass
 from math import ceil
 from typing import Literal

@@ -9,6 +10,7 @@ import cupyx.scipy.signal
 import cyclopts
 import dask.array
+import dask.distributed as dd
 import icechunk
 import icechunk.xarray
 import numpy as np
 import smart_geocubes

@@ -301,18 +303,23 @@ def enrich():
     print("Enrichment complete.")


-@cli.command()
-def aggregate(grid: Literal["hex", "healpix"], level: int):
-    with stopwatch(f"Loading {grid} grid at level {level}"):
-        grid_gdf = grids.open(grid, level)
-        # ? Mask out water, since we don't want to aggregate over oceans
-        grid_gdf = watermask.clip_grid(grid_gdf)
-
+def _open_adem():
     arcticdem_store = get_arcticdem_stores()
     accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
     adem = accessor.open_xarray()
+    assert {"x", "y"} == set(adem.dims)
+    assert adem.odc.crs == "EPSG:3413"
+    return adem
+
+
+@cli.command()
+def aggregate(grid: Literal["hex", "healpix"], level: int, max_workers: int = 20):
+    mp.set_start_method("forkserver", force=True)
+
+    with stopwatch(f"Loading {grid} grid at level {level}"):
+        grid_gdf = grids.open(grid, level)
+        # ? Mask out water, since we don't want to aggregate over oceans
+        grid_gdf = watermask.clip_grid(grid_gdf)

     aggregations = _Aggregations(
         mean=True,

@@ -320,13 +327,29 @@ def aggregate(grid: Literal["hex", "healpix"], level: int):
         std=True,
         min=True,
         max=True,
-        quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99],
+        median=True,
+        quantiles=(0.01, 0.05, 0.25, 0.75, 0.95, 0.99),
     )

-    aggregated = aggregate_raster_into_grid(adem, grid_gdf, aggregations, grid, level)
+    aggregated = aggregate_raster_into_grid(
+        _open_adem,
+        grid_gdf,
+        aggregations,
+        grid,
+        level,
+        max_workers=max_workers,
+        n_partitions=200,  # Results in 10GB large patches
+        pxbuffer=30,
+    )
     store = get_arcticdem_stores(grid, level)
     with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"):
         aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))

     print("Aggregation complete.")

+    print("### Finished ArcticDEM processing ###")
+    stopwatch.summary()
+

 if __name__ == "__main__":
     cli()
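
The forkserver start method set in aggregate() matters for the nested pools: workers are forked from a small, clean server process instead of inheriting the parent's already-large heap, as plain fork would. A minimal sketch of the pattern:

import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

def work(i: int) -> int:
    return i * i

if __name__ == "__main__":
    # With forkserver, each worker starts from a lightweight template process,
    # so an already-loaded raster in the parent is not duplicated into children.
    mp.set_start_method("forkserver", force=True)
    with ProcessPoolExecutor(max_workers=2) as ex:
        print(list(ex.map(work, range(4))))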