Refactor spatial aggregations

This commit is contained in:
Tobias Hölzer 2025-11-22 22:18:45 +01:00
parent 1a71883999
commit e5382670ec
5 changed files with 457 additions and 240 deletions

View file

@ -0,0 +1,254 @@
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Literal
import cupy_xarray
import geopandas as gpd
import numpy as np
import odc.geo.geobox
import shapely
import shapely.ops
import xarray as xr
import xdggs
import xvec
from rich.progress import track
from shapely.geometry import LineString, Polygon
from stopuhr import stopwatch
from xdggs.healpix import HealpixInfo
from entropice import grids
@dataclass
class _Aggregations:
mean: bool = True
sum: bool = False
std: bool = False
min: bool = False
max: bool = False
quantiles: list[float] = field(default_factory=lambda: [])
def varnames(self, vars: list[str] | str) -> list[str]:
if isinstance(vars, str):
vars = [vars]
agg_vars = []
for var in vars:
if self.mean:
agg_vars.append(f"{var}_mean")
if self.sum:
agg_vars.append(f"{var}_sum")
if self.std:
agg_vars.append(f"{var}_std")
if self.min:
agg_vars.append(f"{var}_min")
if self.max:
agg_vars.append(f"{var}_max")
for q in self.quantiles:
q_int = int(q * 100)
agg_vars.append(f"{var}_p{q_int}")
return agg_vars
def _crosses_antimeridian(geom: Polygon) -> bool:
    """Heuristically detect whether a polygon crosses the antimeridian.

    A cell counts as crossing when it has vertices with both positive and
    negative longitude AND at least one vertex beyond +-90 degrees (so cells
    straddling the 0-degree meridian are not flagged).
    """
    lons = shapely.get_coordinates(geom)[:, 0]
    straddles_a_meridian = (lons > 0).any() and (lons < 0).any()
    return straddles_a_meridian and abs(lons).max() > 90
def _split_antimeridian_cell(geom: Polygon) -> list[Polygon]:
    """Split an antimeridian-crossing polygon along the 180-degree meridian.

    Assumes *geom* crosses the antimeridian. Negative longitudes are shifted
    by +360 so the polygon becomes contiguous in [0, 360), then it is cut
    along the 180-degree line.
    """
    coords = shapely.get_coordinates(geom)
    # Vectorized shift of all western-hemisphere vertices into [180, 360).
    coords[coords[:, 0] < 0, 0] += 360
    shifted = Polygon(coords)
    antimeridian = LineString([[180, -90], [180, 90]])
    return list(shapely.ops.split(shifted, antimeridian).geoms)
def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
    """Return True when *geom* spans more than one pixel of *geobox* in both axes.

    Cells that enclose at most one pixel (or whose overlap with the geobox is
    at most one pixel wide/tall) are rejected as too small to aggregate.
    """
    enclosing = geobox.enclosing(geom)
    dim_a, dim_b = enclosing.shape
    if dim_a <= 1 or dim_b <= 1:
        return False
    roi_a, roi_b = geobox.overlap_roi(enclosing)
    overlap_ok = (roi_a.stop - roi_a.start) > 1
    return overlap_ok and (roi_b.stop - roi_b.start) > 1
@stopwatch("Getting corrected geometries", log=False)
def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]:
geom, gbox, crs = inp
# cell.geometry is a shapely Polygon
if not _crosses_antimeridian(geom):
geoms = [geom]
# Split geometry in case it crossed antimeridian
else:
geoms = _split_antimeridian_cell(geom)
geoms = [odc.geo.Geometry(g, crs=crs) for g in geoms]
geoms = list(filter(lambda g: _check_geom(gbox, g), geoms))
return geoms
@stopwatch("Correcting geometries")
def get_corrected_geometries(grid_gdf: gpd.GeoDataFrame, gbox: odc.geo.geobox.GeoBox):
"""Get corrected geometries for antimeridian-crossing polygons.
Args:
grid_gdf (gpd.GeoDataFrame): Grid GeoDataFrame.
gbox (odc.geo.geobox.GeoBox): GeoBox for spatial reference.
Returns:
list[list[odc.geo.Geometry]]: List of corrected, georeferenced geometries.
"""
with ProcessPoolExecutor(max_workers=20) as executor:
cell_geometries = list(
executor.map(_get_corrected_geoms, [(row.geometry, gbox, grid_gdf.crs) for _, row in grid_gdf.iterrows()])
)
return cell_geometries
@stopwatch("Aggregating cell data", log=False)
def _agg_cell_data(flattened: xr.Dataset, aggregations: _Aggregations):
cell_data = {}
for var in flattened.data_vars:
if aggregations.mean:
cell_data[f"{var}_mean"] = flattened[var].mean(dim="z", skipna=True)
if aggregations.sum:
cell_data[f"{var}_sum"] = flattened[var].sum(dim="z", skipna=True)
if aggregations.std:
cell_data[f"{var}_std"] = flattened[var].std(dim="z", skipna=True)
if aggregations.min:
cell_data[f"{var}_min"] = flattened[var].min(dim="z", skipna=True)
if aggregations.max:
cell_data[f"{var}_max"] = flattened[var].max(dim="z", skipna=True)
if len(aggregations.quantiles) > 0:
quantile_values = flattened[var].quantile(
q=aggregations.quantiles,
dim="z",
skipna=True,
)
for q, qv in zip(aggregations.quantiles, quantile_values):
q_int = int(q * 100)
cell_data[f"{var}_p{q_int}"] = qv
# Transform to numpy arrays
for key in cell_data:
cell_data[key] = cell_data[key].cupy.as_numpy().values
return cell_data
@stopwatch("Extracting cell data", log=False)
def _extract_cell_data(ds: xr.Dataset, geom: odc.geo.Geometry, aggregations: _Aggregations):
spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"]
cropped: xr.Dataset = ds.odc.crop(geom).drop_vars("spatial_ref")
flattened = cropped.stack(z=spatdims)
if flattened.z.size > 3000:
flattened = flattened.cupy.as_cupy()
cell_data = _agg_cell_data(flattened, aggregations)
return cell_data
# with np.errstate(divide="ignore", invalid="ignore"):
# cell_data = cropped.mean(spatdims)
# return {var: cell_data[var].values for var in cell_data.data_vars}
@stopwatch("Extracting split cell data", log=False)
def _extract_split_cell_data(ds: xr.Dataset, geoms: list[odc.geo.Geometry], aggregations: _Aggregations):
spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"]
cropped: list[xr.Dataset] = [ds.odc.crop(geom).drop_vars("spatial_ref") for ds, geom in zip(ds, geoms)]
flattened = xr.concat([c.stack(z=spatdims) for c in cropped], dim="z")
cell_data = _agg_cell_data(flattened, aggregations)
return cell_data
# partial_counts = [part.notnull().sum(dim=spatdims) for part in parts]
# with np.errstate(divide="ignore", invalid="ignore"):
# partial_means = [part.sum(dim=spatdims) for part in parts]
# n = xr.concat(partial_counts, dim="part").sum("part")
# cell_data = xr.concat(partial_means, dim="part").sum("part") / n
# return {var: cell_data[var].values for var in cell_data.data_vars}
@stopwatch("Aligning data with grid")
def _align_data(
grid_gdf: gpd.GeoDataFrame,
unaligned: xr.Dataset,
aggregations: _Aggregations,
) -> dict[str, np.ndarray]:
# Persist the dataset, as all the operations here MUST NOT be lazy
unaligned = unaligned.load()
cell_geometries = get_corrected_geometries(grid_gdf, unaligned.odc.geobox)
other_dims_shape = tuple(
[unaligned.sizes[dim] for dim in unaligned.dims if dim not in ["y", "x", "latitude", "longitude"]]
)
data_shape = (len(cell_geometries), *other_dims_shape)
data = {var: np.full(data_shape, np.nan, dtype=np.float32) for var in aggregations.varnames(unaligned.data_vars)}
with ProcessPoolExecutor(max_workers=10) as executor:
futures = {}
for i, geoms in track(
enumerate(cell_geometries), total=len(cell_geometries), description="Submitting cell extraction tasks..."
):
if len(geoms) == 0:
continue
elif len(geoms) == 1:
geom = geoms[0]
# Reduce the amount of data needed to be sent to the worker
# Since we dont mask the data, only isel operations are done here
# Thus, to properly extract the subset, another masked crop needs to be done in the worker
unaligned_subset = unaligned.odc.crop(geom, apply_mask=False)
fut = executor.submit(_extract_cell_data, unaligned_subset, geom, aggregations)
else:
# Same as above but for multiple parts
unaligned_subsets = [unaligned.odc.crop(geom, apply_mask=False) for geom in geoms]
fut = executor.submit(_extract_split_cell_data, unaligned_subsets, geoms, aggregations)
futures[fut] = i
for future in track(
as_completed(futures), total=len(futures), description="Spatially aggregating ERA5 data..."
):
i = futures[future]
cell_data = future.result()
for var in unaligned.data_vars:
data[var][i, :] = cell_data[var]
return data
def aggregate_raster_into_grid(
    raster: xr.Dataset,
    grid_gdf: gpd.GeoDataFrame,
    aggregations: _Aggregations,
    grid: Literal["hex", "healpix"],
    level: int,
) -> xr.Dataset:
    """Spatially aggregate *raster* into a DGGS grid.

    Args:
        raster (xr.Dataset): Input raster with spatial dims (y/x or lat/lon).
        grid_gdf (gpd.GeoDataFrame): Grid cell geometries, one row per cell.
        aggregations (_Aggregations): Which aggregations to compute per variable.
        grid (Literal["hex", "healpix"]): Target grid type.
        level (int): Target grid level.

    Returns:
        xr.Dataset: xdggs-decoded dataset indexed by cell_ids, carrying one
            "<var>_<agg>" variable per input variable and enabled aggregation.
    """
    aligned = _align_data(grid_gdf, raster, aggregations)
    cell_ids = grids.get_cell_ids(grid, level)
    # Keep all non-spatial raster dims (e.g. time) after the leading cell_ids.
    dims = ["cell_ids"]
    coords = {"cell_ids": cell_ids}
    for dim in raster.dims:
        if dim not in ["y", "x", "latitude", "longitude"]:
            dims.append(dim)
            coords[dim] = raster.coords[dim]
    data_vars = {var: (dims, values) for var, values in aligned.items()}
    ongrid = xr.Dataset(
        data_vars,
        coords=coords,
    )
    gridinfo = {
        "grid_name": "h3" if grid == "hex" else grid,
        "level": level,
    }
    if grid == "healpix":
        gridinfo["indexing_scheme"] = "nested"
    ongrid.cell_ids.attrs = gridinfo
    # BUG FIX: copy the attrs into a fresh dict per output variable — assigning
    # the same `raster[var].attrs` mapping to every aggregation aliases them,
    # so mutating one variable's attrs would silently change its siblings'.
    for var in raster.data_vars:
        for v in aggregations.varnames(var):
            ongrid[v].attrs = dict(raster[var].attrs)
    ongrid = xdggs.decode(ongrid)
    return ongrid

View file

@ -1,6 +1,7 @@
import datetime import datetime
from dataclasses import dataclass from dataclasses import dataclass
from math import ceil from math import ceil
from typing import Literal
import cupy as cp import cupy as cp
import cupy_xarray import cupy_xarray
@ -12,6 +13,7 @@ import icechunk.xarray
import numpy as np import numpy as np
import smart_geocubes import smart_geocubes
import xarray as xr import xarray as xr
import xdggs
import xrspatial import xrspatial
import zarr import zarr
from cupyx.scipy.ndimage import binary_dilation, binary_erosion, distance_transform_edt from cupyx.scipy.ndimage import binary_dilation, binary_erosion, distance_transform_edt
@ -22,7 +24,9 @@ from xrspatial.curvature import _run_cupy as curvature_cupy
from xrspatial.slope import _run_cupy as slope_cupy from xrspatial.slope import _run_cupy as slope_cupy
from zarr.codecs import BloscCodec from zarr.codecs import BloscCodec
from entropice.paths import arcticdem_store from entropice import codecs, grids, watermask
from entropice.aggregators import _Aggregations, aggregate_raster_into_grid
from entropice.paths import get_arcticdem_stores
traceback.install(show_locals=True, suppress=[cyclopts]) traceback.install(show_locals=True, suppress=[cyclopts])
pretty.install() pretty.install()
@ -34,6 +38,7 @@ cli = cyclopts.App(name="arcticdem")
@cli.command() @cli.command()
def download(): def download():
arcticdem_store = get_arcticdem_stores()
adem = smart_geocubes.ArcticDEM32m(arcticdem_store) adem = smart_geocubes.ArcticDEM32m(arcticdem_store)
adem.create(exists_ok=True) adem.create(exists_ok=True)
with stopwatch("Download ArcticDEM data"): with stopwatch("Download ArcticDEM data"):
@ -172,7 +177,7 @@ def _enrich_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) ->
tpi_large = tpi_cupy(chunk, large_kernels) tpi_large = tpi_cupy(chunk, large_kernels)
# Slope # Slope
slope = slope_cupy(chunk, res, res) slope = slope_cupy(chunk, res, res)
# Aspect # Aspect & correction
aspect = aspect_cupy(chunk) aspect = aspect_cupy(chunk)
xx, yy = _get_xy_chunk(chunk, x, y, block_info) xx, yy = _get_xy_chunk(chunk, x, y, block_info)
aspect_correction = cp.arctan2(yy, xx) * (180 / cp.pi) + 90 aspect_correction = cp.arctan2(yy, xx) * (180 / cp.pi) + 90
@ -235,12 +240,6 @@ def _enrich(terrain: xr.DataArray):
enriched_da = dask.array.map_overlap( enriched_da = dask.array.map_overlap(
_enrich_chunk, _enrich_chunk,
terrain.data, terrain.data,
# dask.array.from_array(terrain.y.data.reshape(-1, 1), chunks=(terrain.data.chunks[0][0], 3600)).repeat(
# terrain.x.size, axis=1
# ),
# dask.array.from_array(terrain.x.data.reshape(1, -1), chunks=(3600, terrain.data.chunks[1][0])).repeat(
# terrain.y.size, axis=0
# ),
x=terrain.x.to_numpy(), x=terrain.x.to_numpy(),
y=terrain.y.to_numpy(), y=terrain.y.to_numpy(),
depth=15, # large_kernels.size_px depth=15, # large_kernels.size_px
@ -273,14 +272,13 @@ def enrich():
print(client) print(client)
print(client.dashboard_link) print(client.dashboard_link)
arcticdem_store = get_arcticdem_stores()
accessor = smart_geocubes.ArcticDEM32m(arcticdem_store) accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
# Garbage collect from previous runs # Garbage collect from previous runs
accessor.repo.garbage_collect(datetime.datetime.now(datetime.UTC)) accessor.repo.garbage_collect(datetime.datetime.now(datetime.UTC))
adem = accessor.open_xarray() adem = accessor.open_xarray()
# session = adem.repo.readonly_session("main")
# adem = xr.open_zarr(session.store, mask_and_scale=False, consolidated=False).set_coords("spatial_ref")
del adem.y.attrs["_FillValue"] del adem.y.attrs["_FillValue"]
del adem.x.attrs["_FillValue"] del adem.x.attrs["_FillValue"]
enriched, new_features = _enrich(adem.dem) enriched, new_features = _enrich(adem.dem)
@ -289,10 +287,9 @@ def enrich():
adem[feature] = enriched.sel(feature=feature) adem[feature] = enriched.sel(feature=feature)
print(adem[new_features]) print(adem[new_features])
# subset = adem[new_features].isel(x=slice(190800, 220000), y=slice(61200, 100000)) # subset = adem[new_features].isel(x=slice(190800, 220000), y=slice(61200, 100000))
encodings = {feature: {"compressors": [BloscCodec(clevel=5)]} for feature in new_features}
# subset.to_zarr("test2.zarr", mode="w", encoding=encodings)
session = accessor.repo.writable_session("main") session = accessor.repo.writable_session("main")
encodings = {feature: {"compressors": [BloscCodec(clevel=5)]} for feature in new_features}
icechunk.xarray.to_icechunk( icechunk.xarray.to_icechunk(
adem[new_features], adem[new_features],
session, session,
@ -304,5 +301,32 @@ def enrich():
print("Enrichment complete.") print("Enrichment complete.")
@cli.command()
def aggregate(grid: Literal["hex", "healpix"], level: int):
with stopwatch(f"Loading {grid} grid at level {level}"):
grid_gdf = grids.open(grid, level)
# ? Mask out water, since we don't want to aggregate over oceans
grid_gdf = watermask.clip_grid(grid_gdf)
arcticdem_store = get_arcticdem_stores()
accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
adem = accessor.open_xarray()
assert {"x", "y"} == set(adem.dims)
assert adem.odc.crs == "EPSG:3413"
aggregations = _Aggregations(
mean=True,
sum=False,
std=True,
min=True,
max=True,
quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99],
)
aggregated = aggregate_raster_into_grid(adem, grid_gdf, aggregations, grid, level)
store = get_arcticdem_stores(grid, level)
with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"):
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()

View file

@ -1,4 +1,4 @@
# ruff: noqa: PD003, PD011 # ruff: noqa: PD003
"""Download and preprocess ERA5 data. """Download and preprocess ERA5 data.
Variables of Interest: Variables of Interest:
@ -79,7 +79,6 @@ Date: June to October 2025
import cProfile import cProfile
import time import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Literal from typing import Literal
import cyclopts import cyclopts
@ -88,18 +87,15 @@ import numpy as np
import odc.geo import odc.geo
import odc.geo.xr import odc.geo.xr
import pandas as pd import pandas as pd
import shapely
import shapely.ops
import ultraplot as uplt import ultraplot as uplt
import xarray as xr import xarray as xr
import xdggs import xdggs
import xvec import xvec
from rich import pretty, print, traceback from rich import pretty, print, traceback
from rich.progress import track
from shapely.geometry import LineString, Polygon
from stopuhr import stopwatch from stopuhr import stopwatch
from entropice import codecs, grids, watermask from entropice import codecs, grids, watermask
from entropice.aggregators import _Aggregations, aggregate_raster_into_grid
from entropice.paths import FIGURES_DIR, get_era5_stores from entropice.paths import FIGURES_DIR, get_era5_stores
from entropice.xvec import to_xvec from entropice.xvec import to_xvec
@ -647,145 +643,10 @@ def viz(
# =========================== # ===========================
def _crosses_antimeridian(geom: Polygon) -> bool:
coords = shapely.get_coordinates(geom)
crosses_any_meridian = (coords[:, 0] > 0).any() and (coords[:, 0] < 0).any()
return crosses_any_meridian and abs(coords[:, 0]).max() > 90
def _split_antimeridian_cell(geom: Polygon) -> list[Polygon]:
# Assumes that it is a antimeridian hex
coords = shapely.get_coordinates(geom)
for i in range(coords.shape[0]):
if coords[i, 0] < 0:
coords[i, 0] += 360
geom = Polygon(coords)
antimeridian = LineString([[180, -90], [180, 90]])
polys = shapely.ops.split(geom, antimeridian)
return list(polys.geoms)
def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
enclosing = geobox.enclosing(geom)
x, y = enclosing.shape
if x <= 1 or y <= 1:
return False
roi: tuple[slice, slice] = geobox.overlap_roi(enclosing)
roix, roiy = roi
return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1
@stopwatch("Getting corrected geometries", log=False)
def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox]) -> list[odc.geo.Geometry]:
"""Get corrected geometries for antimeridian-crossing polygons.
Args:
geom (Polygon): Input polygon geometry.
gbox (odc.geo.geobox.GeoBox): GeoBox for spatial reference.
Returns:
list[odc.geo.Geometry]: List of corrected, georeferenced geometries.
"""
geom, gbox = inp
# cell.geometry is a shapely Polygon
if not _crosses_antimeridian(geom):
geoms = [geom]
# Split geometry in case it crossed antimeridian
else:
geoms = _split_antimeridian_cell(geom)
geoms = [odc.geo.Geometry(g, crs="epsg:4326") for g in geoms]
geoms = list(filter(lambda g: _check_geom(gbox, g), geoms))
return geoms
def _correct_longs(ds: xr.Dataset) -> xr.Dataset: def _correct_longs(ds: xr.Dataset) -> xr.Dataset:
return ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180)).sortby("longitude") return ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180)).sortby("longitude")
@stopwatch("Extracting cell data", log=False)
def _extract_cell_data(ds, geom):
cropped = ds.odc.crop(geom).drop_vars("spatial_ref")
with np.errstate(divide="ignore", invalid="ignore"):
cell_data = cropped.mean(["latitude", "longitude"])
return {var: cell_data[var].values for var in cell_data.data_vars}
@stopwatch("Extracting split cell data", log=False)
def _extract_split_cell_data(dss: xr.Dataset, geoms):
parts: list[xr.Dataset] = [ds.odc.crop(geom).drop_vars("spatial_ref") for ds, geom in zip(dss, geoms)]
partial_counts = [part.notnull().sum(dim=["latitude", "longitude"]) for part in parts]
with np.errstate(divide="ignore", invalid="ignore"):
partial_means = [part.sum(["latitude", "longitude"]) for part in parts]
n = xr.concat(partial_counts, dim="part").sum("part")
cell_data = xr.concat(partial_means, dim="part").sum("part") / n
return {var: cell_data[var].values for var in cell_data.data_vars}
@stopwatch("Aligning data")
def _align_data(cell_geometries: list[list[odc.geo.Geometry]], unaligned: xr.Dataset) -> dict[str, np.ndarray]:
# Persist the dataset, as all the operations here MUST NOT be lazy
unaligned = unaligned.load()
data = {
var: np.full((len(cell_geometries), len(unaligned.time)), np.nan, dtype=np.float32)
for var in unaligned.data_vars
}
with ProcessPoolExecutor(max_workers=10) as executor:
futures = {}
for i, geoms in track(
enumerate(cell_geometries), total=len(cell_geometries), description="Submitting cell extraction tasks..."
):
if len(geoms) == 0:
continue
elif len(geoms) == 1:
geom = geoms[0]
# Reduce the amount of data needed to be sent to the worker
# Since we dont mask the data, only isel operations are done here
# Thus, to properly extract the subset, another masked crop needs to be done in the worker
unaligned_subset = unaligned.odc.crop(geom, apply_mask=False)
futures[executor.submit(_extract_cell_data, unaligned_subset, geom)] = i
else:
# Same as above but for multiple parts
unaligned_subsets = [unaligned.odc.crop(geom, apply_mask=False) for geom in geoms]
futures[executor.submit(_extract_split_cell_data, unaligned_subsets, geoms)] = i
for future in track(
as_completed(futures), total=len(futures), description="Spatially aggregating ERA5 data..."
):
i = futures[future]
cell_data = future.result()
for var in unaligned.data_vars:
data[var][i, :] = cell_data[var]
return data
@stopwatch("Creating aligned dataset", log=False)
def _create_aligned(
ds: xr.Dataset, data: dict[str, np.ndarray], grid: Literal["hex", "healpix"], level: int
) -> xr.Dataset:
cell_ids = grids.get_cell_ids(grid, level)
data_vars = {var: (["cell_ids", "time"], values) for var, values in data.items()}
aligned = xr.Dataset(
data_vars,
coords={"cell_ids": cell_ids, "time": ds.time},
)
gridinfo = {
"grid_name": "h3" if grid == "hex" else grid,
"level": level,
}
if grid == "healpix":
gridinfo["indexing_scheme"] = "nested"
aligned.cell_ids.attrs = gridinfo
for var in ds.data_vars:
aligned[var].attrs = ds[var].attrs
aligned = aligned.chunk({"cell_ids": min(len(aligned.cell_ids), 10000), "time": len(aligned.time)})
aligned = xdggs.decode(aligned)
return aligned
@cli.command @cli.command
def spatial_agg( def spatial_agg(
grid: Literal["hex", "healpix"], grid: Literal["hex", "healpix"],
@ -816,18 +677,28 @@ def spatial_agg(
assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}" assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}"
unaligned = _correct_longs(unaligned) unaligned = _correct_longs(unaligned)
with stopwatch("Precomputing cell geometries"): # with stopwatch("Precomputing cell geometries"):
with ProcessPoolExecutor(max_workers=20) as executor: # with ProcessPoolExecutor(max_workers=20) as executor:
cell_geometries = list( # cell_geometries = list(
executor.map( # executor.map(
_get_corrected_geoms, # _get_corrected_geoms,
[(row.geometry, unaligned.odc.geobox) for _, row in grid_gdf.iterrows()], # [(row.geometry, unaligned.odc.geobox) for _, row in grid_gdf.iterrows()],
) # )
) # )
# data = _align_data(cell_geometries, unaligned)
# aggregated = _create_aligned(unaligned, data, grid, level)
data = _align_data(cell_geometries, unaligned) aggregations = _Aggregations(
mean=True,
sum=False,
std=True,
min=True,
max=True,
quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99],
)
aggregated = aggregate_raster_into_grid(unaligned, grid_gdf, aggregations, grid, level)
aggregated = _create_aligned(unaligned, data, grid, level) aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)})
store = get_era5_stores(agg, grid, level) store = get_era5_stores(agg, grid, level)
with stopwatch(f"Saving spatially aggregated {agg} ERA5 data to {store}"): with stopwatch(f"Saving spatially aggregated {agg} ERA5 data to {store}"):
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated)) aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))

View file

@ -13,6 +13,7 @@ GRIDS_DIR = DATA_DIR / "grids"
FIGURES_DIR = Path("figures") FIGURES_DIR = Path("figures")
DARTS_DIR = DATA_DIR / "darts" DARTS_DIR = DATA_DIR / "darts"
ERA5_DIR = DATA_DIR / "era5" ERA5_DIR = DATA_DIR / "era5"
ARCTICDEM_DIR = DATA_DIR / "arcticdem"
EMBEDDINGS_DIR = DATA_DIR / "embeddings" EMBEDDINGS_DIR = DATA_DIR / "embeddings"
WATERMASK_DIR = DATA_DIR / "watermask" WATERMASK_DIR = DATA_DIR / "watermask"
TRAINING_DIR = DATA_DIR / "training" TRAINING_DIR = DATA_DIR / "training"
@ -22,12 +23,12 @@ GRIDS_DIR.mkdir(parents=True, exist_ok=True)
FIGURES_DIR.mkdir(parents=True, exist_ok=True) FIGURES_DIR.mkdir(parents=True, exist_ok=True)
DARTS_DIR.mkdir(parents=True, exist_ok=True) DARTS_DIR.mkdir(parents=True, exist_ok=True)
ERA5_DIR.mkdir(parents=True, exist_ok=True) ERA5_DIR.mkdir(parents=True, exist_ok=True)
ARCTICDEM_DIR.mkdir(parents=True, exist_ok=True)
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True) EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
WATERMASK_DIR.mkdir(parents=True, exist_ok=True) WATERMASK_DIR.mkdir(parents=True, exist_ok=True)
TRAINING_DIR.mkdir(parents=True, exist_ok=True) TRAINING_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True) RESULTS_DIR.mkdir(parents=True, exist_ok=True)
arcticdem_store = DATA_DIR / "arcticdem32m.icechunk.zarr"
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp" watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
@ -83,6 +84,17 @@ def get_era5_stores(
return aligned_path return aligned_path
def get_arcticdem_stores(
grid: Literal["hex", "healpix"] | None = None,
level: int | None = None,
):
if grid is None or level is None:
return DATA_DIR / "arcticdem32m.icechunk.zarr"
gridname = _get_gridname(grid, level)
aligned_path = ARCTICDEM_DIR / f"{gridname}_dem.zarr"
return aligned_path
def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path: def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
dataset_file = TRAINING_DIR / f"{gridname}_train_dataset.parquet" dataset_file = TRAINING_DIR / f"{gridname}_train_dataset.parquet"

View file

@ -1,5 +1,5 @@
# ruff: noqa: N806 # ruff: noqa: N806
"""Training dataset preparation and model training.""" """Training of classification models training."""
import pickle import pickle
from typing import Literal from typing import Literal
@ -10,12 +10,15 @@ import pandas as pd
import toml import toml
import torch import torch
import xarray as xr import xarray as xr
from cuml.ensemble import RandomForestClassifier
from cuml.neighbors import KNeighborsClassifier
from entropy import ESPAClassifier from entropy import ESPAClassifier
from rich import pretty, traceback from rich import pretty, traceback
from scipy.stats import loguniform, randint from scipy.stats import loguniform, randint
from sklearn import set_config from sklearn import set_config
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
from stopuhr import stopwatch from stopuhr import stopwatch
from xgboost import XGBClassifier
from entropice.inference import predict_proba from entropice.inference import predict_proba
from entropice.paths import ( from entropice.paths import (
@ -29,13 +32,13 @@ pretty.install()
set_config(array_api_dispatch=True) set_config(array_api_dispatch=True)
def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "multi"] = "binary"): def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "count", "density"] = "binary"):
"""Create X and y data from the training dataset. """Create X and y data from the training dataset.
Args: Args:
grid (Literal["hex", "healpix"]): The grid type to use. grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use. level (int): The grid level to use.
task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary". task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary".
Returns: Returns:
Tuple[pd.DataFrame, pd.DataFrame, pd.Series, list]: The data, Features (X), labels (y), and label names. Tuple[pd.DataFrame, pd.DataFrame, pd.Series, list]: The data, Features (X), labels (y), and label names.
@ -51,7 +54,7 @@ def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["b
if task == "binary": if task == "binary":
labels = ["No RTS", "RTS"] labels = ["No RTS", "RTS"]
y_data = data.loc[X_data.index, "darts_has_rts"] y_data = data.loc[X_data.index, "darts_has_rts"]
else: elif task == "count":
# Put into n categories (log scaled) # Put into n categories (log scaled)
y_data = data.loc[X_data.index, "darts_rts_count"] y_data = data.loc[X_data.index, "darts_rts_count"]
n_categories = 5 n_categories = 5
@ -63,6 +66,17 @@ def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["b
y_data = pd.cut(y_data, bins=bins) y_data = pd.cut(y_data, bins=bins)
labels = [str(v) for v in y_data.sort_values().unique()] labels = [str(v) for v in y_data.sort_values().unique()]
y_data = y_data.cat.codes y_data = y_data.cat.codes
elif task == "density":
y_data = data.loc[X_data.index, "darts_rts_density"]
n_categories = 5
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
# Change the first interval to start at 0
bins = pd.IntervalIndex.from_tuples([(0.0, interval.right) for interval in bins])
y_data = pd.cut(y_data, bins=bins)
labels = [str(v) for v in y_data.sort_values().unique()]
y_data = y_data.cat.codes
else:
raise ValueError(f"Unknown task: {task}")
return data, X_data, y_data, labels return data, X_data, y_data, labels
@ -71,7 +85,8 @@ def random_cv(
level: int, level: int,
n_iter: int = 2000, n_iter: int = 2000,
robust: bool = False, robust: bool = False,
task: Literal["binary", "multi"] = "binary", task: Literal["binary", "count", "density"] = "binary",
model: Literal["espa", "xgboost", "rf", "knn"] = "espa",
): ):
"""Perform random cross-validation on the training dataset. """Perform random cross-validation on the training dataset.
@ -80,38 +95,74 @@ def random_cv(
level (int): The grid level to use. level (int): The grid level to use.
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000. n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
robust (bool, optional): Whether to use robust training. Defaults to False. robust (bool, optional): Whether to use robust training. Defaults to False.
task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary". task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary".
""" """
_, X_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task) _, X_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task)
print(f"Using {task}-class classification with {len(labels)} classes: {labels}") print(f"Using {task}-class classification with {len(labels)} classes: {labels}")
print(f"{y_data.describe()=}") print(f"{y_data.describe()=}")
X = X_data.to_numpy(dtype="float32") X = X_data.to_numpy(dtype="float64")
y = y_data.to_numpy(dtype="int8") y = y_data.to_numpy(dtype="int8")
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0) X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
print(f"{X.shape=}, {y.shape=}") print(f"{X.shape=}, {y.shape=}")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}") print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
param_grid = { if model == "espa":
"eps_cl": loguniform(1e-3, 1e7), clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
"eps_e": loguniform(1e-3, 1e7), if task == "binary":
"initial_K": randint(20, 400), param_grid = {
} "eps_cl": loguniform(1e-4, 1e1),
"eps_e": loguniform(1e1, 1e7),
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust) "initial_K": randint(20, 400),
}
else:
param_grid = {
"eps_cl": loguniform(1e-11, 1e-6),
"eps_e": loguniform(1e4, 1e8),
"initial_K": randint(400, 800),
}
elif model == "xgboost":
param_grid = {
"learning_rate": loguniform(1e-4, 1e-1),
"max_depth": randint(3, 15),
"n_estimators": randint(100, 1000),
"subsample": loguniform(0.5, 1.0),
"colsample_bytree": loguniform(0.5, 1.0),
}
clf = XGBClassifier(
objective="multi:softprob" if task != "binary" else "binary:logistic",
eval_metric="mlogloss" if task != "binary" else "logloss",
random_state=42,
tree_method="gpu_hist",
device="cuda",
)
elif model == "rf":
param_grid = {
"max_depth": randint(5, 50),
"n_estimators": randint(50, 500),
}
clf = RandomForestClassifier(random_state=42)
elif model == "knn":
param_grid = {
"n_neighbors": randint(3, 15),
"weights": ["uniform", "distance"],
"algorithm": ["brute", "kd_tree", "ball_tree"],
}
clf = KNeighborsClassifier(random_state=42)
else:
raise ValueError(f"Unknown model: {model}")
cv = KFold(n_splits=5, shuffle=True, random_state=42) cv = KFold(n_splits=5, shuffle=True, random_state=42)
if task == "binary": if task == "binary":
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
else: else:
metrics = [ metrics = [
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro", "accuracy", # equals "f1_micro", "precision_micro", "recall_micro", "recall_weighted"
"f1_macro", "f1_macro",
"f1_weighted", "f1_weighted",
"precision_macro", "precision_macro",
"precision_weighted", "precision_weighted",
"recall_macro", "recall_macro",
"recall_weighted",
"jaccard_micro", "jaccard_micro",
"jaccard_macro", "jaccard_macro",
"jaccard_weighted", "jaccard_weighted",
@ -120,17 +171,17 @@ def random_cv(
clf, clf,
param_grid, param_grid,
n_iter=n_iter, n_iter=n_iter,
n_jobs=16, n_jobs=8,
cv=cv, cv=cv,
random_state=42, random_state=42,
verbose=3, verbose=10,
scoring=metrics, scoring=metrics,
refit="f1" if task == "binary" else "f1_weighted", refit="f1" if task == "binary" else "f1_weighted",
) )
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...") print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"): with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
search.fit(X_train, y_train, max_iter=100) search.fit(X_train, y_train, max_iter=300)
print("Best parameters combination found:") print("Best parameters combination found:")
best_parameters = search.best_estimator_.get_params() best_parameters = search.best_estimator_.get_params()
@ -144,29 +195,33 @@ def random_cv(
results_dir = get_cv_results_dir("random_search", grid=grid, level=level, task=task) results_dir = get_cv_results_dir("random_search", grid=grid, level=level, task=task)
# Store the search settings # Store the search settings
# First convert the param_grid distributions to a serializable format
param_grid_serializable = {}
for key, dist in param_grid.items():
if isinstance(dist, loguniform):
param_grid_serializable[key] = {
"distribution": "loguniform",
"low": dist.a,
"high": dist.b,
}
elif isinstance(dist, randint):
param_grid_serializable[key] = {
"distribution": "randint",
"low": dist.a,
"high": dist.b,
}
elif isinstance(dist, list):
param_grid_serializable[key] = dist
else:
raise ValueError(f"Unknown distribution type for {key}: {type(dist)}")
settings = { settings = {
"task": task, "task": task,
"model": model,
"grid": grid, "grid": grid,
"level": level, "level": level,
"random_state": 42, "random_state": 42,
"n_iter": n_iter, "n_iter": n_iter,
"param_grid": { "param_grid": param_grid_serializable,
"eps_cl": {
"distribution": "loguniform",
"low": param_grid["eps_cl"].a,
"high": param_grid["eps_cl"].b,
},
"eps_e": {
"distribution": "loguniform",
"low": param_grid["eps_e"].a,
"high": param_grid["eps_e"].b,
},
"initial_K": {
"distribution": "randint",
"low": param_grid["initial_K"].a,
"high": param_grid["initial_K"].b,
},
},
"cv_splits": cv.get_n_splits(), "cv_splits": cv.get_n_splits(),
"metrics": metrics, "metrics": metrics,
"classes": labels, "classes": labels,
@ -195,46 +250,47 @@ def random_cv(
results.to_parquet(results_file) results.to_parquet(results_file)
# Get the inner state of the best estimator # Get the inner state of the best estimator
best_estimator = search.best_estimator_ if model == "espa":
# Annotate the state with xarray metadata best_estimator = search.best_estimator_
features = X_data.columns.tolist() # Annotate the state with xarray metadata
boxes = list(range(best_estimator.K_)) features = X_data.columns.tolist()
box_centers = xr.DataArray( boxes = list(range(best_estimator.K_))
best_estimator.S_.cpu().numpy(), box_centers = xr.DataArray(
dims=["feature", "box"], best_estimator.S_.cpu().numpy(),
coords={"feature": features, "box": boxes}, dims=["feature", "box"],
name="box_centers", coords={"feature": features, "box": boxes},
attrs={"description": "Centers of the boxes in feature space."}, name="box_centers",
) attrs={"description": "Centers of the boxes in feature space."},
box_assignments = xr.DataArray( )
best_estimator.Lambda_.cpu().numpy(), box_assignments = xr.DataArray(
dims=["class", "box"], best_estimator.Lambda_.cpu().numpy(),
coords={"class": labels, "box": boxes}, dims=["class", "box"],
name="box_assignments", coords={"class": labels, "box": boxes},
attrs={"description": "Assignments of samples to boxes."}, name="box_assignments",
) attrs={"description": "Assignments of samples to boxes."},
feature_weights = xr.DataArray( )
best_estimator.W_.cpu().numpy(), feature_weights = xr.DataArray(
dims=["feature"], best_estimator.W_.cpu().numpy(),
coords={"feature": features}, dims=["feature"],
name="feature_weights", coords={"feature": features},
attrs={"description": "Feature weights for each box."}, name="feature_weights",
) attrs={"description": "Feature weights for each box."},
state = xr.Dataset( )
{ state = xr.Dataset(
"box_centers": box_centers, {
"box_assignments": box_assignments, "box_centers": box_centers,
"feature_weights": feature_weights, "box_assignments": box_assignments,
}, "feature_weights": feature_weights,
attrs={ },
"description": "Inner state of the best ESPAClassifier from RandomizedSearchCV.", attrs={
"grid": grid, "description": "Inner state of the best ESPAClassifier from RandomizedSearchCV.",
"level": level, "grid": grid,
}, "level": level,
) },
state_file = results_dir / "best_estimator_state.nc" )
print(f"Storing best estimator state to {state_file}") state_file = results_dir / "best_estimator_state.nc"
state.to_netcdf(state_file, engine="h5netcdf") print(f"Storing best estimator state to {state_file}")
state.to_netcdf(state_file, engine="h5netcdf")
# Predict probabilities for all cells # Predict probabilities for all cells
print("Predicting probabilities for all cells...") print("Predicting probabilities for all cells...")