Refactor spatial aggregations
This commit is contained in:
parent
1a71883999
commit
e5382670ec
5 changed files with 457 additions and 240 deletions
254
src/entropice/aggregators.py
Normal file
254
src/entropice/aggregators.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
import cupy_xarray
|
||||
import geopandas as gpd
|
||||
import numpy as np
|
||||
import odc.geo.geobox
|
||||
import shapely
|
||||
import shapely.ops
|
||||
import xarray as xr
|
||||
import xdggs
|
||||
import xvec
|
||||
from rich.progress import track
|
||||
from shapely.geometry import LineString, Polygon
|
||||
from stopuhr import stopwatch
|
||||
from xdggs.healpix import HealpixInfo
|
||||
|
||||
from entropice import grids
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Aggregations:
|
||||
mean: bool = True
|
||||
sum: bool = False
|
||||
std: bool = False
|
||||
min: bool = False
|
||||
max: bool = False
|
||||
quantiles: list[float] = field(default_factory=lambda: [])
|
||||
|
||||
def varnames(self, vars: list[str] | str) -> list[str]:
|
||||
if isinstance(vars, str):
|
||||
vars = [vars]
|
||||
agg_vars = []
|
||||
for var in vars:
|
||||
if self.mean:
|
||||
agg_vars.append(f"{var}_mean")
|
||||
if self.sum:
|
||||
agg_vars.append(f"{var}_sum")
|
||||
if self.std:
|
||||
agg_vars.append(f"{var}_std")
|
||||
if self.min:
|
||||
agg_vars.append(f"{var}_min")
|
||||
if self.max:
|
||||
agg_vars.append(f"{var}_max")
|
||||
for q in self.quantiles:
|
||||
q_int = int(q * 100)
|
||||
agg_vars.append(f"{var}_p{q_int}")
|
||||
return agg_vars
|
||||
|
||||
|
||||
def _crosses_antimeridian(geom: Polygon) -> bool:
|
||||
coords = shapely.get_coordinates(geom)
|
||||
crosses_any_meridian = (coords[:, 0] > 0).any() and (coords[:, 0] < 0).any()
|
||||
return crosses_any_meridian and abs(coords[:, 0]).max() > 90
|
||||
|
||||
|
||||
def _split_antimeridian_cell(geom: Polygon) -> list[Polygon]:
|
||||
# Assumes that it is a antimeridian hex
|
||||
coords = shapely.get_coordinates(geom)
|
||||
for i in range(coords.shape[0]):
|
||||
if coords[i, 0] < 0:
|
||||
coords[i, 0] += 360
|
||||
geom = Polygon(coords)
|
||||
antimeridian = LineString([[180, -90], [180, 90]])
|
||||
polys = shapely.ops.split(geom, antimeridian)
|
||||
return list(polys.geoms)
|
||||
|
||||
|
||||
def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
|
||||
enclosing = geobox.enclosing(geom)
|
||||
x, y = enclosing.shape
|
||||
if x <= 1 or y <= 1:
|
||||
return False
|
||||
roi: tuple[slice, slice] = geobox.overlap_roi(enclosing)
|
||||
roix, roiy = roi
|
||||
return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1
|
||||
|
||||
|
||||
@stopwatch("Getting corrected geometries", log=False)
|
||||
def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]:
|
||||
geom, gbox, crs = inp
|
||||
# cell.geometry is a shapely Polygon
|
||||
if not _crosses_antimeridian(geom):
|
||||
geoms = [geom]
|
||||
# Split geometry in case it crossed antimeridian
|
||||
else:
|
||||
geoms = _split_antimeridian_cell(geom)
|
||||
|
||||
geoms = [odc.geo.Geometry(g, crs=crs) for g in geoms]
|
||||
geoms = list(filter(lambda g: _check_geom(gbox, g), geoms))
|
||||
return geoms
|
||||
|
||||
|
||||
@stopwatch("Correcting geometries")
|
||||
def get_corrected_geometries(grid_gdf: gpd.GeoDataFrame, gbox: odc.geo.geobox.GeoBox):
|
||||
"""Get corrected geometries for antimeridian-crossing polygons.
|
||||
|
||||
Args:
|
||||
grid_gdf (gpd.GeoDataFrame): Grid GeoDataFrame.
|
||||
gbox (odc.geo.geobox.GeoBox): GeoBox for spatial reference.
|
||||
|
||||
Returns:
|
||||
list[list[odc.geo.Geometry]]: List of corrected, georeferenced geometries.
|
||||
|
||||
"""
|
||||
with ProcessPoolExecutor(max_workers=20) as executor:
|
||||
cell_geometries = list(
|
||||
executor.map(_get_corrected_geoms, [(row.geometry, gbox, grid_gdf.crs) for _, row in grid_gdf.iterrows()])
|
||||
)
|
||||
return cell_geometries
|
||||
|
||||
|
||||
@stopwatch("Aggregating cell data", log=False)
|
||||
def _agg_cell_data(flattened: xr.Dataset, aggregations: _Aggregations):
|
||||
cell_data = {}
|
||||
for var in flattened.data_vars:
|
||||
if aggregations.mean:
|
||||
cell_data[f"{var}_mean"] = flattened[var].mean(dim="z", skipna=True)
|
||||
if aggregations.sum:
|
||||
cell_data[f"{var}_sum"] = flattened[var].sum(dim="z", skipna=True)
|
||||
if aggregations.std:
|
||||
cell_data[f"{var}_std"] = flattened[var].std(dim="z", skipna=True)
|
||||
if aggregations.min:
|
||||
cell_data[f"{var}_min"] = flattened[var].min(dim="z", skipna=True)
|
||||
if aggregations.max:
|
||||
cell_data[f"{var}_max"] = flattened[var].max(dim="z", skipna=True)
|
||||
if len(aggregations.quantiles) > 0:
|
||||
quantile_values = flattened[var].quantile(
|
||||
q=aggregations.quantiles,
|
||||
dim="z",
|
||||
skipna=True,
|
||||
)
|
||||
for q, qv in zip(aggregations.quantiles, quantile_values):
|
||||
q_int = int(q * 100)
|
||||
cell_data[f"{var}_p{q_int}"] = qv
|
||||
# Transform to numpy arrays
|
||||
for key in cell_data:
|
||||
cell_data[key] = cell_data[key].cupy.as_numpy().values
|
||||
return cell_data
|
||||
|
||||
|
||||
@stopwatch("Extracting cell data", log=False)
|
||||
def _extract_cell_data(ds: xr.Dataset, geom: odc.geo.Geometry, aggregations: _Aggregations):
|
||||
spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"]
|
||||
cropped: xr.Dataset = ds.odc.crop(geom).drop_vars("spatial_ref")
|
||||
flattened = cropped.stack(z=spatdims)
|
||||
if flattened.z.size > 3000:
|
||||
flattened = flattened.cupy.as_cupy()
|
||||
cell_data = _agg_cell_data(flattened, aggregations)
|
||||
return cell_data
|
||||
|
||||
# with np.errstate(divide="ignore", invalid="ignore"):
|
||||
# cell_data = cropped.mean(spatdims)
|
||||
# return {var: cell_data[var].values for var in cell_data.data_vars}
|
||||
|
||||
|
||||
@stopwatch("Extracting split cell data", log=False)
|
||||
def _extract_split_cell_data(ds: xr.Dataset, geoms: list[odc.geo.Geometry], aggregations: _Aggregations):
|
||||
spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"]
|
||||
cropped: list[xr.Dataset] = [ds.odc.crop(geom).drop_vars("spatial_ref") for ds, geom in zip(ds, geoms)]
|
||||
flattened = xr.concat([c.stack(z=spatdims) for c in cropped], dim="z")
|
||||
cell_data = _agg_cell_data(flattened, aggregations)
|
||||
return cell_data
|
||||
|
||||
# partial_counts = [part.notnull().sum(dim=spatdims) for part in parts]
|
||||
# with np.errstate(divide="ignore", invalid="ignore"):
|
||||
# partial_means = [part.sum(dim=spatdims) for part in parts]
|
||||
# n = xr.concat(partial_counts, dim="part").sum("part")
|
||||
# cell_data = xr.concat(partial_means, dim="part").sum("part") / n
|
||||
# return {var: cell_data[var].values for var in cell_data.data_vars}
|
||||
|
||||
|
||||
@stopwatch("Aligning data with grid")
|
||||
def _align_data(
|
||||
grid_gdf: gpd.GeoDataFrame,
|
||||
unaligned: xr.Dataset,
|
||||
aggregations: _Aggregations,
|
||||
) -> dict[str, np.ndarray]:
|
||||
# Persist the dataset, as all the operations here MUST NOT be lazy
|
||||
unaligned = unaligned.load()
|
||||
|
||||
cell_geometries = get_corrected_geometries(grid_gdf, unaligned.odc.geobox)
|
||||
other_dims_shape = tuple(
|
||||
[unaligned.sizes[dim] for dim in unaligned.dims if dim not in ["y", "x", "latitude", "longitude"]]
|
||||
)
|
||||
data_shape = (len(cell_geometries), *other_dims_shape)
|
||||
data = {var: np.full(data_shape, np.nan, dtype=np.float32) for var in aggregations.varnames(unaligned.data_vars)}
|
||||
with ProcessPoolExecutor(max_workers=10) as executor:
|
||||
futures = {}
|
||||
for i, geoms in track(
|
||||
enumerate(cell_geometries), total=len(cell_geometries), description="Submitting cell extraction tasks..."
|
||||
):
|
||||
if len(geoms) == 0:
|
||||
continue
|
||||
elif len(geoms) == 1:
|
||||
geom = geoms[0]
|
||||
# Reduce the amount of data needed to be sent to the worker
|
||||
# Since we dont mask the data, only isel operations are done here
|
||||
# Thus, to properly extract the subset, another masked crop needs to be done in the worker
|
||||
unaligned_subset = unaligned.odc.crop(geom, apply_mask=False)
|
||||
fut = executor.submit(_extract_cell_data, unaligned_subset, geom, aggregations)
|
||||
else:
|
||||
# Same as above but for multiple parts
|
||||
unaligned_subsets = [unaligned.odc.crop(geom, apply_mask=False) for geom in geoms]
|
||||
fut = executor.submit(_extract_split_cell_data, unaligned_subsets, geoms, aggregations)
|
||||
futures[fut] = i
|
||||
|
||||
for future in track(
|
||||
as_completed(futures), total=len(futures), description="Spatially aggregating ERA5 data..."
|
||||
):
|
||||
i = futures[future]
|
||||
cell_data = future.result()
|
||||
for var in unaligned.data_vars:
|
||||
data[var][i, :] = cell_data[var]
|
||||
return data
|
||||
|
||||
|
||||
def aggregate_raster_into_grid(
|
||||
raster: xr.Dataset,
|
||||
grid_gdf: gpd.GeoDataFrame,
|
||||
aggregations: _Aggregations,
|
||||
grid: Literal["hex", "healpix"],
|
||||
level: int,
|
||||
):
|
||||
aligned = _align_data(grid_gdf, raster, aggregations)
|
||||
|
||||
cell_ids = grids.get_cell_ids(grid, level)
|
||||
|
||||
dims = ["cell_ids"]
|
||||
coords = {"cell_ids": cell_ids}
|
||||
for dim in raster.dims:
|
||||
if dim not in ["y", "x", "latitude", "longitude"]:
|
||||
dims.append(dim)
|
||||
coords[dim] = raster.coords[dim]
|
||||
|
||||
data_vars = {var: (dims, values) for var, values in aligned.items()}
|
||||
ongrid = xr.Dataset(
|
||||
data_vars,
|
||||
coords=coords,
|
||||
)
|
||||
gridinfo = {
|
||||
"grid_name": "h3" if grid == "hex" else grid,
|
||||
"level": level,
|
||||
}
|
||||
if grid == "healpix":
|
||||
gridinfo["indexing_scheme"] = "nested"
|
||||
ongrid.cell_ids.attrs = gridinfo
|
||||
for var in raster.data_vars:
|
||||
for v in aggregations.varnames(var):
|
||||
ongrid[v].attrs = raster[var].attrs
|
||||
|
||||
ongrid = xdggs.decode(ongrid)
|
||||
return ongrid
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import datetime
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Literal
|
||||
|
||||
import cupy as cp
|
||||
import cupy_xarray
|
||||
|
|
@ -12,6 +13,7 @@ import icechunk.xarray
|
|||
import numpy as np
|
||||
import smart_geocubes
|
||||
import xarray as xr
|
||||
import xdggs
|
||||
import xrspatial
|
||||
import zarr
|
||||
from cupyx.scipy.ndimage import binary_dilation, binary_erosion, distance_transform_edt
|
||||
|
|
@ -22,7 +24,9 @@ from xrspatial.curvature import _run_cupy as curvature_cupy
|
|||
from xrspatial.slope import _run_cupy as slope_cupy
|
||||
from zarr.codecs import BloscCodec
|
||||
|
||||
from entropice.paths import arcticdem_store
|
||||
from entropice import codecs, grids, watermask
|
||||
from entropice.aggregators import _Aggregations, aggregate_raster_into_grid
|
||||
from entropice.paths import get_arcticdem_stores
|
||||
|
||||
traceback.install(show_locals=True, suppress=[cyclopts])
|
||||
pretty.install()
|
||||
|
|
@ -34,6 +38,7 @@ cli = cyclopts.App(name="arcticdem")
|
|||
|
||||
@cli.command()
|
||||
def download():
|
||||
arcticdem_store = get_arcticdem_stores()
|
||||
adem = smart_geocubes.ArcticDEM32m(arcticdem_store)
|
||||
adem.create(exists_ok=True)
|
||||
with stopwatch("Download ArcticDEM data"):
|
||||
|
|
@ -172,7 +177,7 @@ def _enrich_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) ->
|
|||
tpi_large = tpi_cupy(chunk, large_kernels)
|
||||
# Slope
|
||||
slope = slope_cupy(chunk, res, res)
|
||||
# Aspect
|
||||
# Aspect & correction
|
||||
aspect = aspect_cupy(chunk)
|
||||
xx, yy = _get_xy_chunk(chunk, x, y, block_info)
|
||||
aspect_correction = cp.arctan2(yy, xx) * (180 / cp.pi) + 90
|
||||
|
|
@ -235,12 +240,6 @@ def _enrich(terrain: xr.DataArray):
|
|||
enriched_da = dask.array.map_overlap(
|
||||
_enrich_chunk,
|
||||
terrain.data,
|
||||
# dask.array.from_array(terrain.y.data.reshape(-1, 1), chunks=(terrain.data.chunks[0][0], 3600)).repeat(
|
||||
# terrain.x.size, axis=1
|
||||
# ),
|
||||
# dask.array.from_array(terrain.x.data.reshape(1, -1), chunks=(3600, terrain.data.chunks[1][0])).repeat(
|
||||
# terrain.y.size, axis=0
|
||||
# ),
|
||||
x=terrain.x.to_numpy(),
|
||||
y=terrain.y.to_numpy(),
|
||||
depth=15, # large_kernels.size_px
|
||||
|
|
@ -273,14 +272,13 @@ def enrich():
|
|||
print(client)
|
||||
print(client.dashboard_link)
|
||||
|
||||
arcticdem_store = get_arcticdem_stores()
|
||||
accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
|
||||
|
||||
# Garbage collect from previous runs
|
||||
accessor.repo.garbage_collect(datetime.datetime.now(datetime.UTC))
|
||||
|
||||
adem = accessor.open_xarray()
|
||||
# session = adem.repo.readonly_session("main")
|
||||
# adem = xr.open_zarr(session.store, mask_and_scale=False, consolidated=False).set_coords("spatial_ref")
|
||||
del adem.y.attrs["_FillValue"]
|
||||
del adem.x.attrs["_FillValue"]
|
||||
enriched, new_features = _enrich(adem.dem)
|
||||
|
|
@ -289,10 +287,9 @@ def enrich():
|
|||
adem[feature] = enriched.sel(feature=feature)
|
||||
print(adem[new_features])
|
||||
# subset = adem[new_features].isel(x=slice(190800, 220000), y=slice(61200, 100000))
|
||||
encodings = {feature: {"compressors": [BloscCodec(clevel=5)]} for feature in new_features}
|
||||
# subset.to_zarr("test2.zarr", mode="w", encoding=encodings)
|
||||
|
||||
session = accessor.repo.writable_session("main")
|
||||
encodings = {feature: {"compressors": [BloscCodec(clevel=5)]} for feature in new_features}
|
||||
icechunk.xarray.to_icechunk(
|
||||
adem[new_features],
|
||||
session,
|
||||
|
|
@ -304,5 +301,32 @@ def enrich():
|
|||
print("Enrichment complete.")
|
||||
|
||||
|
||||
@cli.command()
|
||||
def aggregate(grid: Literal["hex", "healpix"], level: int):
|
||||
with stopwatch(f"Loading {grid} grid at level {level}"):
|
||||
grid_gdf = grids.open(grid, level)
|
||||
# ? Mask out water, since we don't want to aggregate over oceans
|
||||
grid_gdf = watermask.clip_grid(grid_gdf)
|
||||
|
||||
arcticdem_store = get_arcticdem_stores()
|
||||
accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
|
||||
adem = accessor.open_xarray()
|
||||
assert {"x", "y"} == set(adem.dims)
|
||||
assert adem.odc.crs == "EPSG:3413"
|
||||
|
||||
aggregations = _Aggregations(
|
||||
mean=True,
|
||||
sum=False,
|
||||
std=True,
|
||||
min=True,
|
||||
max=True,
|
||||
quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99],
|
||||
)
|
||||
aggregated = aggregate_raster_into_grid(adem, grid_gdf, aggregations, grid, level)
|
||||
store = get_arcticdem_stores(grid, level)
|
||||
with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"):
|
||||
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# ruff: noqa: PD003, PD011
|
||||
# ruff: noqa: PD003
|
||||
"""Download and preprocess ERA5 data.
|
||||
|
||||
Variables of Interest:
|
||||
|
|
@ -79,7 +79,6 @@ Date: June to October 2025
|
|||
|
||||
import cProfile
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from typing import Literal
|
||||
|
||||
import cyclopts
|
||||
|
|
@ -88,18 +87,15 @@ import numpy as np
|
|||
import odc.geo
|
||||
import odc.geo.xr
|
||||
import pandas as pd
|
||||
import shapely
|
||||
import shapely.ops
|
||||
import ultraplot as uplt
|
||||
import xarray as xr
|
||||
import xdggs
|
||||
import xvec
|
||||
from rich import pretty, print, traceback
|
||||
from rich.progress import track
|
||||
from shapely.geometry import LineString, Polygon
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice import codecs, grids, watermask
|
||||
from entropice.aggregators import _Aggregations, aggregate_raster_into_grid
|
||||
from entropice.paths import FIGURES_DIR, get_era5_stores
|
||||
from entropice.xvec import to_xvec
|
||||
|
||||
|
|
@ -647,145 +643,10 @@ def viz(
|
|||
# ===========================
|
||||
|
||||
|
||||
def _crosses_antimeridian(geom: Polygon) -> bool:
|
||||
coords = shapely.get_coordinates(geom)
|
||||
crosses_any_meridian = (coords[:, 0] > 0).any() and (coords[:, 0] < 0).any()
|
||||
return crosses_any_meridian and abs(coords[:, 0]).max() > 90
|
||||
|
||||
|
||||
def _split_antimeridian_cell(geom: Polygon) -> list[Polygon]:
|
||||
# Assumes that it is a antimeridian hex
|
||||
coords = shapely.get_coordinates(geom)
|
||||
for i in range(coords.shape[0]):
|
||||
if coords[i, 0] < 0:
|
||||
coords[i, 0] += 360
|
||||
geom = Polygon(coords)
|
||||
antimeridian = LineString([[180, -90], [180, 90]])
|
||||
polys = shapely.ops.split(geom, antimeridian)
|
||||
return list(polys.geoms)
|
||||
|
||||
|
||||
def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
|
||||
enclosing = geobox.enclosing(geom)
|
||||
x, y = enclosing.shape
|
||||
if x <= 1 or y <= 1:
|
||||
return False
|
||||
roi: tuple[slice, slice] = geobox.overlap_roi(enclosing)
|
||||
roix, roiy = roi
|
||||
return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1
|
||||
|
||||
|
||||
@stopwatch("Getting corrected geometries", log=False)
|
||||
def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox]) -> list[odc.geo.Geometry]:
|
||||
"""Get corrected geometries for antimeridian-crossing polygons.
|
||||
|
||||
Args:
|
||||
geom (Polygon): Input polygon geometry.
|
||||
gbox (odc.geo.geobox.GeoBox): GeoBox for spatial reference.
|
||||
|
||||
Returns:
|
||||
list[odc.geo.Geometry]: List of corrected, georeferenced geometries.
|
||||
|
||||
"""
|
||||
geom, gbox = inp
|
||||
# cell.geometry is a shapely Polygon
|
||||
if not _crosses_antimeridian(geom):
|
||||
geoms = [geom]
|
||||
# Split geometry in case it crossed antimeridian
|
||||
else:
|
||||
geoms = _split_antimeridian_cell(geom)
|
||||
|
||||
geoms = [odc.geo.Geometry(g, crs="epsg:4326") for g in geoms]
|
||||
geoms = list(filter(lambda g: _check_geom(gbox, g), geoms))
|
||||
return geoms
|
||||
|
||||
|
||||
def _correct_longs(ds: xr.Dataset) -> xr.Dataset:
|
||||
return ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180)).sortby("longitude")
|
||||
|
||||
|
||||
@stopwatch("Extracting cell data", log=False)
|
||||
def _extract_cell_data(ds, geom):
|
||||
cropped = ds.odc.crop(geom).drop_vars("spatial_ref")
|
||||
with np.errstate(divide="ignore", invalid="ignore"):
|
||||
cell_data = cropped.mean(["latitude", "longitude"])
|
||||
return {var: cell_data[var].values for var in cell_data.data_vars}
|
||||
|
||||
|
||||
@stopwatch("Extracting split cell data", log=False)
|
||||
def _extract_split_cell_data(dss: xr.Dataset, geoms):
|
||||
parts: list[xr.Dataset] = [ds.odc.crop(geom).drop_vars("spatial_ref") for ds, geom in zip(dss, geoms)]
|
||||
partial_counts = [part.notnull().sum(dim=["latitude", "longitude"]) for part in parts]
|
||||
with np.errstate(divide="ignore", invalid="ignore"):
|
||||
partial_means = [part.sum(["latitude", "longitude"]) for part in parts]
|
||||
n = xr.concat(partial_counts, dim="part").sum("part")
|
||||
cell_data = xr.concat(partial_means, dim="part").sum("part") / n
|
||||
return {var: cell_data[var].values for var in cell_data.data_vars}
|
||||
|
||||
|
||||
@stopwatch("Aligning data")
|
||||
def _align_data(cell_geometries: list[list[odc.geo.Geometry]], unaligned: xr.Dataset) -> dict[str, np.ndarray]:
|
||||
# Persist the dataset, as all the operations here MUST NOT be lazy
|
||||
unaligned = unaligned.load()
|
||||
|
||||
data = {
|
||||
var: np.full((len(cell_geometries), len(unaligned.time)), np.nan, dtype=np.float32)
|
||||
for var in unaligned.data_vars
|
||||
}
|
||||
with ProcessPoolExecutor(max_workers=10) as executor:
|
||||
futures = {}
|
||||
for i, geoms in track(
|
||||
enumerate(cell_geometries), total=len(cell_geometries), description="Submitting cell extraction tasks..."
|
||||
):
|
||||
if len(geoms) == 0:
|
||||
continue
|
||||
elif len(geoms) == 1:
|
||||
geom = geoms[0]
|
||||
# Reduce the amount of data needed to be sent to the worker
|
||||
# Since we dont mask the data, only isel operations are done here
|
||||
# Thus, to properly extract the subset, another masked crop needs to be done in the worker
|
||||
unaligned_subset = unaligned.odc.crop(geom, apply_mask=False)
|
||||
futures[executor.submit(_extract_cell_data, unaligned_subset, geom)] = i
|
||||
else:
|
||||
# Same as above but for multiple parts
|
||||
unaligned_subsets = [unaligned.odc.crop(geom, apply_mask=False) for geom in geoms]
|
||||
futures[executor.submit(_extract_split_cell_data, unaligned_subsets, geoms)] = i
|
||||
|
||||
for future in track(
|
||||
as_completed(futures), total=len(futures), description="Spatially aggregating ERA5 data..."
|
||||
):
|
||||
i = futures[future]
|
||||
cell_data = future.result()
|
||||
for var in unaligned.data_vars:
|
||||
data[var][i, :] = cell_data[var]
|
||||
return data
|
||||
|
||||
|
||||
@stopwatch("Creating aligned dataset", log=False)
|
||||
def _create_aligned(
|
||||
ds: xr.Dataset, data: dict[str, np.ndarray], grid: Literal["hex", "healpix"], level: int
|
||||
) -> xr.Dataset:
|
||||
cell_ids = grids.get_cell_ids(grid, level)
|
||||
data_vars = {var: (["cell_ids", "time"], values) for var, values in data.items()}
|
||||
aligned = xr.Dataset(
|
||||
data_vars,
|
||||
coords={"cell_ids": cell_ids, "time": ds.time},
|
||||
)
|
||||
gridinfo = {
|
||||
"grid_name": "h3" if grid == "hex" else grid,
|
||||
"level": level,
|
||||
}
|
||||
if grid == "healpix":
|
||||
gridinfo["indexing_scheme"] = "nested"
|
||||
aligned.cell_ids.attrs = gridinfo
|
||||
for var in ds.data_vars:
|
||||
aligned[var].attrs = ds[var].attrs
|
||||
|
||||
aligned = aligned.chunk({"cell_ids": min(len(aligned.cell_ids), 10000), "time": len(aligned.time)})
|
||||
aligned = xdggs.decode(aligned)
|
||||
return aligned
|
||||
|
||||
|
||||
@cli.command
|
||||
def spatial_agg(
|
||||
grid: Literal["hex", "healpix"],
|
||||
|
|
@ -816,18 +677,28 @@ def spatial_agg(
|
|||
assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}"
|
||||
unaligned = _correct_longs(unaligned)
|
||||
|
||||
with stopwatch("Precomputing cell geometries"):
|
||||
with ProcessPoolExecutor(max_workers=20) as executor:
|
||||
cell_geometries = list(
|
||||
executor.map(
|
||||
_get_corrected_geoms,
|
||||
[(row.geometry, unaligned.odc.geobox) for _, row in grid_gdf.iterrows()],
|
||||
)
|
||||
)
|
||||
# with stopwatch("Precomputing cell geometries"):
|
||||
# with ProcessPoolExecutor(max_workers=20) as executor:
|
||||
# cell_geometries = list(
|
||||
# executor.map(
|
||||
# _get_corrected_geoms,
|
||||
# [(row.geometry, unaligned.odc.geobox) for _, row in grid_gdf.iterrows()],
|
||||
# )
|
||||
# )
|
||||
# data = _align_data(cell_geometries, unaligned)
|
||||
# aggregated = _create_aligned(unaligned, data, grid, level)
|
||||
|
||||
data = _align_data(cell_geometries, unaligned)
|
||||
aggregations = _Aggregations(
|
||||
mean=True,
|
||||
sum=False,
|
||||
std=True,
|
||||
min=True,
|
||||
max=True,
|
||||
quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99],
|
||||
)
|
||||
aggregated = aggregate_raster_into_grid(unaligned, grid_gdf, aggregations, grid, level)
|
||||
|
||||
aggregated = _create_aligned(unaligned, data, grid, level)
|
||||
aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)})
|
||||
store = get_era5_stores(agg, grid, level)
|
||||
with stopwatch(f"Saving spatially aggregated {agg} ERA5 data to {store}"):
|
||||
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ GRIDS_DIR = DATA_DIR / "grids"
|
|||
FIGURES_DIR = Path("figures")
|
||||
DARTS_DIR = DATA_DIR / "darts"
|
||||
ERA5_DIR = DATA_DIR / "era5"
|
||||
ARCTICDEM_DIR = DATA_DIR / "arcticdem"
|
||||
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
|
||||
WATERMASK_DIR = DATA_DIR / "watermask"
|
||||
TRAINING_DIR = DATA_DIR / "training"
|
||||
|
|
@ -22,12 +23,12 @@ GRIDS_DIR.mkdir(parents=True, exist_ok=True)
|
|||
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
DARTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
ERA5_DIR.mkdir(parents=True, exist_ok=True)
|
||||
ARCTICDEM_DIR.mkdir(parents=True, exist_ok=True)
|
||||
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
WATERMASK_DIR.mkdir(parents=True, exist_ok=True)
|
||||
TRAINING_DIR.mkdir(parents=True, exist_ok=True)
|
||||
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
arcticdem_store = DATA_DIR / "arcticdem32m.icechunk.zarr"
|
||||
|
||||
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
|
||||
|
||||
|
|
@ -83,6 +84,17 @@ def get_era5_stores(
|
|||
return aligned_path
|
||||
|
||||
|
||||
def get_arcticdem_stores(
|
||||
grid: Literal["hex", "healpix"] | None = None,
|
||||
level: int | None = None,
|
||||
):
|
||||
if grid is None or level is None:
|
||||
return DATA_DIR / "arcticdem32m.icechunk.zarr"
|
||||
gridname = _get_gridname(grid, level)
|
||||
aligned_path = ARCTICDEM_DIR / f"{gridname}_dem.zarr"
|
||||
return aligned_path
|
||||
|
||||
|
||||
def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||
gridname = _get_gridname(grid, level)
|
||||
dataset_file = TRAINING_DIR / f"{gridname}_train_dataset.parquet"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# ruff: noqa: N806
|
||||
"""Training dataset preparation and model training."""
|
||||
"""Training of classification models training."""
|
||||
|
||||
import pickle
|
||||
from typing import Literal
|
||||
|
|
@ -10,12 +10,15 @@ import pandas as pd
|
|||
import toml
|
||||
import torch
|
||||
import xarray as xr
|
||||
from cuml.ensemble import RandomForestClassifier
|
||||
from cuml.neighbors import KNeighborsClassifier
|
||||
from entropy import ESPAClassifier
|
||||
from rich import pretty, traceback
|
||||
from scipy.stats import loguniform, randint
|
||||
from sklearn import set_config
|
||||
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
|
||||
from stopuhr import stopwatch
|
||||
from xgboost import XGBClassifier
|
||||
|
||||
from entropice.inference import predict_proba
|
||||
from entropice.paths import (
|
||||
|
|
@ -29,13 +32,13 @@ pretty.install()
|
|||
set_config(array_api_dispatch=True)
|
||||
|
||||
|
||||
def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "multi"] = "binary"):
|
||||
def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "count", "density"] = "binary"):
|
||||
"""Create X and y data from the training dataset.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary".
|
||||
task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary".
|
||||
|
||||
Returns:
|
||||
Tuple[pd.DataFrame, pd.DataFrame, pd.Series, list]: The data, Features (X), labels (y), and label names.
|
||||
|
|
@ -51,7 +54,7 @@ def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["b
|
|||
if task == "binary":
|
||||
labels = ["No RTS", "RTS"]
|
||||
y_data = data.loc[X_data.index, "darts_has_rts"]
|
||||
else:
|
||||
elif task == "count":
|
||||
# Put into n categories (log scaled)
|
||||
y_data = data.loc[X_data.index, "darts_rts_count"]
|
||||
n_categories = 5
|
||||
|
|
@ -63,6 +66,17 @@ def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["b
|
|||
y_data = pd.cut(y_data, bins=bins)
|
||||
labels = [str(v) for v in y_data.sort_values().unique()]
|
||||
y_data = y_data.cat.codes
|
||||
elif task == "density":
|
||||
y_data = data.loc[X_data.index, "darts_rts_density"]
|
||||
n_categories = 5
|
||||
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
|
||||
# Change the first interval to start at 0
|
||||
bins = pd.IntervalIndex.from_tuples([(0.0, interval.right) for interval in bins])
|
||||
y_data = pd.cut(y_data, bins=bins)
|
||||
labels = [str(v) for v in y_data.sort_values().unique()]
|
||||
y_data = y_data.cat.codes
|
||||
else:
|
||||
raise ValueError(f"Unknown task: {task}")
|
||||
return data, X_data, y_data, labels
|
||||
|
||||
|
||||
|
|
@ -71,7 +85,8 @@ def random_cv(
|
|||
level: int,
|
||||
n_iter: int = 2000,
|
||||
robust: bool = False,
|
||||
task: Literal["binary", "multi"] = "binary",
|
||||
task: Literal["binary", "count", "density"] = "binary",
|
||||
model: Literal["espa", "xgboost", "rf", "knn"] = "espa",
|
||||
):
|
||||
"""Perform random cross-validation on the training dataset.
|
||||
|
||||
|
|
@ -80,38 +95,74 @@ def random_cv(
|
|||
level (int): The grid level to use.
|
||||
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
||||
robust (bool, optional): Whether to use robust training. Defaults to False.
|
||||
task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary".
|
||||
task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary".
|
||||
|
||||
"""
|
||||
_, X_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task)
|
||||
print(f"Using {task}-class classification with {len(labels)} classes: {labels}")
|
||||
print(f"{y_data.describe()=}")
|
||||
X = X_data.to_numpy(dtype="float32")
|
||||
X = X_data.to_numpy(dtype="float64")
|
||||
y = y_data.to_numpy(dtype="int8")
|
||||
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
|
||||
print(f"{X.shape=}, {y.shape=}")
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
|
||||
|
||||
param_grid = {
|
||||
"eps_cl": loguniform(1e-3, 1e7),
|
||||
"eps_e": loguniform(1e-3, 1e7),
|
||||
"initial_K": randint(20, 400),
|
||||
}
|
||||
|
||||
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
|
||||
if model == "espa":
|
||||
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
|
||||
if task == "binary":
|
||||
param_grid = {
|
||||
"eps_cl": loguniform(1e-4, 1e1),
|
||||
"eps_e": loguniform(1e1, 1e7),
|
||||
"initial_K": randint(20, 400),
|
||||
}
|
||||
else:
|
||||
param_grid = {
|
||||
"eps_cl": loguniform(1e-11, 1e-6),
|
||||
"eps_e": loguniform(1e4, 1e8),
|
||||
"initial_K": randint(400, 800),
|
||||
}
|
||||
elif model == "xgboost":
|
||||
param_grid = {
|
||||
"learning_rate": loguniform(1e-4, 1e-1),
|
||||
"max_depth": randint(3, 15),
|
||||
"n_estimators": randint(100, 1000),
|
||||
"subsample": loguniform(0.5, 1.0),
|
||||
"colsample_bytree": loguniform(0.5, 1.0),
|
||||
}
|
||||
clf = XGBClassifier(
|
||||
objective="multi:softprob" if task != "binary" else "binary:logistic",
|
||||
eval_metric="mlogloss" if task != "binary" else "logloss",
|
||||
random_state=42,
|
||||
tree_method="gpu_hist",
|
||||
device="cuda",
|
||||
)
|
||||
elif model == "rf":
|
||||
param_grid = {
|
||||
"max_depth": randint(5, 50),
|
||||
"n_estimators": randint(50, 500),
|
||||
}
|
||||
clf = RandomForestClassifier(random_state=42)
|
||||
elif model == "knn":
|
||||
param_grid = {
|
||||
"n_neighbors": randint(3, 15),
|
||||
"weights": ["uniform", "distance"],
|
||||
"algorithm": ["brute", "kd_tree", "ball_tree"],
|
||||
}
|
||||
clf = KNeighborsClassifier(random_state=42)
|
||||
else:
|
||||
raise ValueError(f"Unknown model: {model}")
|
||||
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
||||
if task == "binary":
|
||||
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
|
||||
else:
|
||||
metrics = [
|
||||
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro",
|
||||
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro", "recall_weighted"
|
||||
"f1_macro",
|
||||
"f1_weighted",
|
||||
"precision_macro",
|
||||
"precision_weighted",
|
||||
"recall_macro",
|
||||
"recall_weighted",
|
||||
"jaccard_micro",
|
||||
"jaccard_macro",
|
||||
"jaccard_weighted",
|
||||
|
|
@ -120,17 +171,17 @@ def random_cv(
|
|||
clf,
|
||||
param_grid,
|
||||
n_iter=n_iter,
|
||||
n_jobs=16,
|
||||
n_jobs=8,
|
||||
cv=cv,
|
||||
random_state=42,
|
||||
verbose=3,
|
||||
verbose=10,
|
||||
scoring=metrics,
|
||||
refit="f1" if task == "binary" else "f1_weighted",
|
||||
)
|
||||
|
||||
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
|
||||
with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
|
||||
search.fit(X_train, y_train, max_iter=100)
|
||||
search.fit(X_train, y_train, max_iter=300)
|
||||
|
||||
print("Best parameters combination found:")
|
||||
best_parameters = search.best_estimator_.get_params()
|
||||
|
|
@ -144,29 +195,33 @@ def random_cv(
|
|||
results_dir = get_cv_results_dir("random_search", grid=grid, level=level, task=task)
|
||||
|
||||
# Store the search settings
|
||||
# First convert the param_grid distributions to a serializable format
|
||||
param_grid_serializable = {}
|
||||
for key, dist in param_grid.items():
|
||||
if isinstance(dist, loguniform):
|
||||
param_grid_serializable[key] = {
|
||||
"distribution": "loguniform",
|
||||
"low": dist.a,
|
||||
"high": dist.b,
|
||||
}
|
||||
elif isinstance(dist, randint):
|
||||
param_grid_serializable[key] = {
|
||||
"distribution": "randint",
|
||||
"low": dist.a,
|
||||
"high": dist.b,
|
||||
}
|
||||
elif isinstance(dist, list):
|
||||
param_grid_serializable[key] = dist
|
||||
else:
|
||||
raise ValueError(f"Unknown distribution type for {key}: {type(dist)}")
|
||||
settings = {
|
||||
"task": task,
|
||||
"model": model,
|
||||
"grid": grid,
|
||||
"level": level,
|
||||
"random_state": 42,
|
||||
"n_iter": n_iter,
|
||||
"param_grid": {
|
||||
"eps_cl": {
|
||||
"distribution": "loguniform",
|
||||
"low": param_grid["eps_cl"].a,
|
||||
"high": param_grid["eps_cl"].b,
|
||||
},
|
||||
"eps_e": {
|
||||
"distribution": "loguniform",
|
||||
"low": param_grid["eps_e"].a,
|
||||
"high": param_grid["eps_e"].b,
|
||||
},
|
||||
"initial_K": {
|
||||
"distribution": "randint",
|
||||
"low": param_grid["initial_K"].a,
|
||||
"high": param_grid["initial_K"].b,
|
||||
},
|
||||
},
|
||||
"param_grid": param_grid_serializable,
|
||||
"cv_splits": cv.get_n_splits(),
|
||||
"metrics": metrics,
|
||||
"classes": labels,
|
||||
|
|
@ -195,46 +250,47 @@ def random_cv(
|
|||
results.to_parquet(results_file)
|
||||
|
||||
# Get the inner state of the best estimator
|
||||
best_estimator = search.best_estimator_
|
||||
# Annotate the state with xarray metadata
|
||||
features = X_data.columns.tolist()
|
||||
boxes = list(range(best_estimator.K_))
|
||||
box_centers = xr.DataArray(
|
||||
best_estimator.S_.cpu().numpy(),
|
||||
dims=["feature", "box"],
|
||||
coords={"feature": features, "box": boxes},
|
||||
name="box_centers",
|
||||
attrs={"description": "Centers of the boxes in feature space."},
|
||||
)
|
||||
box_assignments = xr.DataArray(
|
||||
best_estimator.Lambda_.cpu().numpy(),
|
||||
dims=["class", "box"],
|
||||
coords={"class": labels, "box": boxes},
|
||||
name="box_assignments",
|
||||
attrs={"description": "Assignments of samples to boxes."},
|
||||
)
|
||||
feature_weights = xr.DataArray(
|
||||
best_estimator.W_.cpu().numpy(),
|
||||
dims=["feature"],
|
||||
coords={"feature": features},
|
||||
name="feature_weights",
|
||||
attrs={"description": "Feature weights for each box."},
|
||||
)
|
||||
state = xr.Dataset(
|
||||
{
|
||||
"box_centers": box_centers,
|
||||
"box_assignments": box_assignments,
|
||||
"feature_weights": feature_weights,
|
||||
},
|
||||
attrs={
|
||||
"description": "Inner state of the best ESPAClassifier from RandomizedSearchCV.",
|
||||
"grid": grid,
|
||||
"level": level,
|
||||
},
|
||||
)
|
||||
state_file = results_dir / "best_estimator_state.nc"
|
||||
print(f"Storing best estimator state to {state_file}")
|
||||
state.to_netcdf(state_file, engine="h5netcdf")
|
||||
if model == "espa":
|
||||
best_estimator = search.best_estimator_
|
||||
# Annotate the state with xarray metadata
|
||||
features = X_data.columns.tolist()
|
||||
boxes = list(range(best_estimator.K_))
|
||||
box_centers = xr.DataArray(
|
||||
best_estimator.S_.cpu().numpy(),
|
||||
dims=["feature", "box"],
|
||||
coords={"feature": features, "box": boxes},
|
||||
name="box_centers",
|
||||
attrs={"description": "Centers of the boxes in feature space."},
|
||||
)
|
||||
box_assignments = xr.DataArray(
|
||||
best_estimator.Lambda_.cpu().numpy(),
|
||||
dims=["class", "box"],
|
||||
coords={"class": labels, "box": boxes},
|
||||
name="box_assignments",
|
||||
attrs={"description": "Assignments of samples to boxes."},
|
||||
)
|
||||
feature_weights = xr.DataArray(
|
||||
best_estimator.W_.cpu().numpy(),
|
||||
dims=["feature"],
|
||||
coords={"feature": features},
|
||||
name="feature_weights",
|
||||
attrs={"description": "Feature weights for each box."},
|
||||
)
|
||||
state = xr.Dataset(
|
||||
{
|
||||
"box_centers": box_centers,
|
||||
"box_assignments": box_assignments,
|
||||
"feature_weights": feature_weights,
|
||||
},
|
||||
attrs={
|
||||
"description": "Inner state of the best ESPAClassifier from RandomizedSearchCV.",
|
||||
"grid": grid,
|
||||
"level": level,
|
||||
},
|
||||
)
|
||||
state_file = results_dir / "best_estimator_state.nc"
|
||||
print(f"Storing best estimator state to {state_file}")
|
||||
state.to_netcdf(state_file, engine="h5netcdf")
|
||||
|
||||
# Predict probabilities for all cells
|
||||
print("Predicting probabilities for all cells...")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue