Refactor spatial aggregations
This commit is contained in:
parent
1a71883999
commit
e5382670ec
5 changed files with 457 additions and 240 deletions
254
src/entropice/aggregators.py
Normal file
254
src/entropice/aggregators.py
Normal file
|
|
@ -0,0 +1,254 @@
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import cupy_xarray
|
||||||
|
import geopandas as gpd
|
||||||
|
import numpy as np
|
||||||
|
import odc.geo.geobox
|
||||||
|
import shapely
|
||||||
|
import shapely.ops
|
||||||
|
import xarray as xr
|
||||||
|
import xdggs
|
||||||
|
import xvec
|
||||||
|
from rich.progress import track
|
||||||
|
from shapely.geometry import LineString, Polygon
|
||||||
|
from stopuhr import stopwatch
|
||||||
|
from xdggs.healpix import HealpixInfo
|
||||||
|
|
||||||
|
from entropice import grids
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _Aggregations:
|
||||||
|
mean: bool = True
|
||||||
|
sum: bool = False
|
||||||
|
std: bool = False
|
||||||
|
min: bool = False
|
||||||
|
max: bool = False
|
||||||
|
quantiles: list[float] = field(default_factory=lambda: [])
|
||||||
|
|
||||||
|
def varnames(self, vars: list[str] | str) -> list[str]:
|
||||||
|
if isinstance(vars, str):
|
||||||
|
vars = [vars]
|
||||||
|
agg_vars = []
|
||||||
|
for var in vars:
|
||||||
|
if self.mean:
|
||||||
|
agg_vars.append(f"{var}_mean")
|
||||||
|
if self.sum:
|
||||||
|
agg_vars.append(f"{var}_sum")
|
||||||
|
if self.std:
|
||||||
|
agg_vars.append(f"{var}_std")
|
||||||
|
if self.min:
|
||||||
|
agg_vars.append(f"{var}_min")
|
||||||
|
if self.max:
|
||||||
|
agg_vars.append(f"{var}_max")
|
||||||
|
for q in self.quantiles:
|
||||||
|
q_int = int(q * 100)
|
||||||
|
agg_vars.append(f"{var}_p{q_int}")
|
||||||
|
return agg_vars
|
||||||
|
|
||||||
|
|
||||||
|
def _crosses_antimeridian(geom: Polygon) -> bool:
|
||||||
|
coords = shapely.get_coordinates(geom)
|
||||||
|
crosses_any_meridian = (coords[:, 0] > 0).any() and (coords[:, 0] < 0).any()
|
||||||
|
return crosses_any_meridian and abs(coords[:, 0]).max() > 90
|
||||||
|
|
||||||
|
|
||||||
|
def _split_antimeridian_cell(geom: Polygon) -> list[Polygon]:
|
||||||
|
# Assumes that it is a antimeridian hex
|
||||||
|
coords = shapely.get_coordinates(geom)
|
||||||
|
for i in range(coords.shape[0]):
|
||||||
|
if coords[i, 0] < 0:
|
||||||
|
coords[i, 0] += 360
|
||||||
|
geom = Polygon(coords)
|
||||||
|
antimeridian = LineString([[180, -90], [180, 90]])
|
||||||
|
polys = shapely.ops.split(geom, antimeridian)
|
||||||
|
return list(polys.geoms)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
|
||||||
|
enclosing = geobox.enclosing(geom)
|
||||||
|
x, y = enclosing.shape
|
||||||
|
if x <= 1 or y <= 1:
|
||||||
|
return False
|
||||||
|
roi: tuple[slice, slice] = geobox.overlap_roi(enclosing)
|
||||||
|
roix, roiy = roi
|
||||||
|
return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch("Getting corrected geometries", log=False)
|
||||||
|
def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox, str]) -> list[odc.geo.Geometry]:
|
||||||
|
geom, gbox, crs = inp
|
||||||
|
# cell.geometry is a shapely Polygon
|
||||||
|
if not _crosses_antimeridian(geom):
|
||||||
|
geoms = [geom]
|
||||||
|
# Split geometry in case it crossed antimeridian
|
||||||
|
else:
|
||||||
|
geoms = _split_antimeridian_cell(geom)
|
||||||
|
|
||||||
|
geoms = [odc.geo.Geometry(g, crs=crs) for g in geoms]
|
||||||
|
geoms = list(filter(lambda g: _check_geom(gbox, g), geoms))
|
||||||
|
return geoms
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch("Correcting geometries")
|
||||||
|
def get_corrected_geometries(grid_gdf: gpd.GeoDataFrame, gbox: odc.geo.geobox.GeoBox):
|
||||||
|
"""Get corrected geometries for antimeridian-crossing polygons.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid_gdf (gpd.GeoDataFrame): Grid GeoDataFrame.
|
||||||
|
gbox (odc.geo.geobox.GeoBox): GeoBox for spatial reference.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[list[odc.geo.Geometry]]: List of corrected, georeferenced geometries.
|
||||||
|
|
||||||
|
"""
|
||||||
|
with ProcessPoolExecutor(max_workers=20) as executor:
|
||||||
|
cell_geometries = list(
|
||||||
|
executor.map(_get_corrected_geoms, [(row.geometry, gbox, grid_gdf.crs) for _, row in grid_gdf.iterrows()])
|
||||||
|
)
|
||||||
|
return cell_geometries
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch("Aggregating cell data", log=False)
|
||||||
|
def _agg_cell_data(flattened: xr.Dataset, aggregations: _Aggregations):
|
||||||
|
cell_data = {}
|
||||||
|
for var in flattened.data_vars:
|
||||||
|
if aggregations.mean:
|
||||||
|
cell_data[f"{var}_mean"] = flattened[var].mean(dim="z", skipna=True)
|
||||||
|
if aggregations.sum:
|
||||||
|
cell_data[f"{var}_sum"] = flattened[var].sum(dim="z", skipna=True)
|
||||||
|
if aggregations.std:
|
||||||
|
cell_data[f"{var}_std"] = flattened[var].std(dim="z", skipna=True)
|
||||||
|
if aggregations.min:
|
||||||
|
cell_data[f"{var}_min"] = flattened[var].min(dim="z", skipna=True)
|
||||||
|
if aggregations.max:
|
||||||
|
cell_data[f"{var}_max"] = flattened[var].max(dim="z", skipna=True)
|
||||||
|
if len(aggregations.quantiles) > 0:
|
||||||
|
quantile_values = flattened[var].quantile(
|
||||||
|
q=aggregations.quantiles,
|
||||||
|
dim="z",
|
||||||
|
skipna=True,
|
||||||
|
)
|
||||||
|
for q, qv in zip(aggregations.quantiles, quantile_values):
|
||||||
|
q_int = int(q * 100)
|
||||||
|
cell_data[f"{var}_p{q_int}"] = qv
|
||||||
|
# Transform to numpy arrays
|
||||||
|
for key in cell_data:
|
||||||
|
cell_data[key] = cell_data[key].cupy.as_numpy().values
|
||||||
|
return cell_data
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch("Extracting cell data", log=False)
|
||||||
|
def _extract_cell_data(ds: xr.Dataset, geom: odc.geo.Geometry, aggregations: _Aggregations):
|
||||||
|
spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"]
|
||||||
|
cropped: xr.Dataset = ds.odc.crop(geom).drop_vars("spatial_ref")
|
||||||
|
flattened = cropped.stack(z=spatdims)
|
||||||
|
if flattened.z.size > 3000:
|
||||||
|
flattened = flattened.cupy.as_cupy()
|
||||||
|
cell_data = _agg_cell_data(flattened, aggregations)
|
||||||
|
return cell_data
|
||||||
|
|
||||||
|
# with np.errstate(divide="ignore", invalid="ignore"):
|
||||||
|
# cell_data = cropped.mean(spatdims)
|
||||||
|
# return {var: cell_data[var].values for var in cell_data.data_vars}
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch("Extracting split cell data", log=False)
|
||||||
|
def _extract_split_cell_data(ds: xr.Dataset, geoms: list[odc.geo.Geometry], aggregations: _Aggregations):
|
||||||
|
spatdims = ["latitude", "longitude"] if "latitude" in ds.dims and "longitude" in ds.dims else ["y", "x"]
|
||||||
|
cropped: list[xr.Dataset] = [ds.odc.crop(geom).drop_vars("spatial_ref") for ds, geom in zip(ds, geoms)]
|
||||||
|
flattened = xr.concat([c.stack(z=spatdims) for c in cropped], dim="z")
|
||||||
|
cell_data = _agg_cell_data(flattened, aggregations)
|
||||||
|
return cell_data
|
||||||
|
|
||||||
|
# partial_counts = [part.notnull().sum(dim=spatdims) for part in parts]
|
||||||
|
# with np.errstate(divide="ignore", invalid="ignore"):
|
||||||
|
# partial_means = [part.sum(dim=spatdims) for part in parts]
|
||||||
|
# n = xr.concat(partial_counts, dim="part").sum("part")
|
||||||
|
# cell_data = xr.concat(partial_means, dim="part").sum("part") / n
|
||||||
|
# return {var: cell_data[var].values for var in cell_data.data_vars}
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch("Aligning data with grid")
|
||||||
|
def _align_data(
|
||||||
|
grid_gdf: gpd.GeoDataFrame,
|
||||||
|
unaligned: xr.Dataset,
|
||||||
|
aggregations: _Aggregations,
|
||||||
|
) -> dict[str, np.ndarray]:
|
||||||
|
# Persist the dataset, as all the operations here MUST NOT be lazy
|
||||||
|
unaligned = unaligned.load()
|
||||||
|
|
||||||
|
cell_geometries = get_corrected_geometries(grid_gdf, unaligned.odc.geobox)
|
||||||
|
other_dims_shape = tuple(
|
||||||
|
[unaligned.sizes[dim] for dim in unaligned.dims if dim not in ["y", "x", "latitude", "longitude"]]
|
||||||
|
)
|
||||||
|
data_shape = (len(cell_geometries), *other_dims_shape)
|
||||||
|
data = {var: np.full(data_shape, np.nan, dtype=np.float32) for var in aggregations.varnames(unaligned.data_vars)}
|
||||||
|
with ProcessPoolExecutor(max_workers=10) as executor:
|
||||||
|
futures = {}
|
||||||
|
for i, geoms in track(
|
||||||
|
enumerate(cell_geometries), total=len(cell_geometries), description="Submitting cell extraction tasks..."
|
||||||
|
):
|
||||||
|
if len(geoms) == 0:
|
||||||
|
continue
|
||||||
|
elif len(geoms) == 1:
|
||||||
|
geom = geoms[0]
|
||||||
|
# Reduce the amount of data needed to be sent to the worker
|
||||||
|
# Since we dont mask the data, only isel operations are done here
|
||||||
|
# Thus, to properly extract the subset, another masked crop needs to be done in the worker
|
||||||
|
unaligned_subset = unaligned.odc.crop(geom, apply_mask=False)
|
||||||
|
fut = executor.submit(_extract_cell_data, unaligned_subset, geom, aggregations)
|
||||||
|
else:
|
||||||
|
# Same as above but for multiple parts
|
||||||
|
unaligned_subsets = [unaligned.odc.crop(geom, apply_mask=False) for geom in geoms]
|
||||||
|
fut = executor.submit(_extract_split_cell_data, unaligned_subsets, geoms, aggregations)
|
||||||
|
futures[fut] = i
|
||||||
|
|
||||||
|
for future in track(
|
||||||
|
as_completed(futures), total=len(futures), description="Spatially aggregating ERA5 data..."
|
||||||
|
):
|
||||||
|
i = futures[future]
|
||||||
|
cell_data = future.result()
|
||||||
|
for var in unaligned.data_vars:
|
||||||
|
data[var][i, :] = cell_data[var]
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_raster_into_grid(
|
||||||
|
raster: xr.Dataset,
|
||||||
|
grid_gdf: gpd.GeoDataFrame,
|
||||||
|
aggregations: _Aggregations,
|
||||||
|
grid: Literal["hex", "healpix"],
|
||||||
|
level: int,
|
||||||
|
):
|
||||||
|
aligned = _align_data(grid_gdf, raster, aggregations)
|
||||||
|
|
||||||
|
cell_ids = grids.get_cell_ids(grid, level)
|
||||||
|
|
||||||
|
dims = ["cell_ids"]
|
||||||
|
coords = {"cell_ids": cell_ids}
|
||||||
|
for dim in raster.dims:
|
||||||
|
if dim not in ["y", "x", "latitude", "longitude"]:
|
||||||
|
dims.append(dim)
|
||||||
|
coords[dim] = raster.coords[dim]
|
||||||
|
|
||||||
|
data_vars = {var: (dims, values) for var, values in aligned.items()}
|
||||||
|
ongrid = xr.Dataset(
|
||||||
|
data_vars,
|
||||||
|
coords=coords,
|
||||||
|
)
|
||||||
|
gridinfo = {
|
||||||
|
"grid_name": "h3" if grid == "hex" else grid,
|
||||||
|
"level": level,
|
||||||
|
}
|
||||||
|
if grid == "healpix":
|
||||||
|
gridinfo["indexing_scheme"] = "nested"
|
||||||
|
ongrid.cell_ids.attrs = gridinfo
|
||||||
|
for var in raster.data_vars:
|
||||||
|
for v in aggregations.varnames(var):
|
||||||
|
ongrid[v].attrs = raster[var].attrs
|
||||||
|
|
||||||
|
ongrid = xdggs.decode(ongrid)
|
||||||
|
return ongrid
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import datetime
|
import datetime
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
import cupy_xarray
|
import cupy_xarray
|
||||||
|
|
@ -12,6 +13,7 @@ import icechunk.xarray
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import smart_geocubes
|
import smart_geocubes
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
|
import xdggs
|
||||||
import xrspatial
|
import xrspatial
|
||||||
import zarr
|
import zarr
|
||||||
from cupyx.scipy.ndimage import binary_dilation, binary_erosion, distance_transform_edt
|
from cupyx.scipy.ndimage import binary_dilation, binary_erosion, distance_transform_edt
|
||||||
|
|
@ -22,7 +24,9 @@ from xrspatial.curvature import _run_cupy as curvature_cupy
|
||||||
from xrspatial.slope import _run_cupy as slope_cupy
|
from xrspatial.slope import _run_cupy as slope_cupy
|
||||||
from zarr.codecs import BloscCodec
|
from zarr.codecs import BloscCodec
|
||||||
|
|
||||||
from entropice.paths import arcticdem_store
|
from entropice import codecs, grids, watermask
|
||||||
|
from entropice.aggregators import _Aggregations, aggregate_raster_into_grid
|
||||||
|
from entropice.paths import get_arcticdem_stores
|
||||||
|
|
||||||
traceback.install(show_locals=True, suppress=[cyclopts])
|
traceback.install(show_locals=True, suppress=[cyclopts])
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
@ -34,6 +38,7 @@ cli = cyclopts.App(name="arcticdem")
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def download():
|
def download():
|
||||||
|
arcticdem_store = get_arcticdem_stores()
|
||||||
adem = smart_geocubes.ArcticDEM32m(arcticdem_store)
|
adem = smart_geocubes.ArcticDEM32m(arcticdem_store)
|
||||||
adem.create(exists_ok=True)
|
adem.create(exists_ok=True)
|
||||||
with stopwatch("Download ArcticDEM data"):
|
with stopwatch("Download ArcticDEM data"):
|
||||||
|
|
@ -172,7 +177,7 @@ def _enrich_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) ->
|
||||||
tpi_large = tpi_cupy(chunk, large_kernels)
|
tpi_large = tpi_cupy(chunk, large_kernels)
|
||||||
# Slope
|
# Slope
|
||||||
slope = slope_cupy(chunk, res, res)
|
slope = slope_cupy(chunk, res, res)
|
||||||
# Aspect
|
# Aspect & correction
|
||||||
aspect = aspect_cupy(chunk)
|
aspect = aspect_cupy(chunk)
|
||||||
xx, yy = _get_xy_chunk(chunk, x, y, block_info)
|
xx, yy = _get_xy_chunk(chunk, x, y, block_info)
|
||||||
aspect_correction = cp.arctan2(yy, xx) * (180 / cp.pi) + 90
|
aspect_correction = cp.arctan2(yy, xx) * (180 / cp.pi) + 90
|
||||||
|
|
@ -235,12 +240,6 @@ def _enrich(terrain: xr.DataArray):
|
||||||
enriched_da = dask.array.map_overlap(
|
enriched_da = dask.array.map_overlap(
|
||||||
_enrich_chunk,
|
_enrich_chunk,
|
||||||
terrain.data,
|
terrain.data,
|
||||||
# dask.array.from_array(terrain.y.data.reshape(-1, 1), chunks=(terrain.data.chunks[0][0], 3600)).repeat(
|
|
||||||
# terrain.x.size, axis=1
|
|
||||||
# ),
|
|
||||||
# dask.array.from_array(terrain.x.data.reshape(1, -1), chunks=(3600, terrain.data.chunks[1][0])).repeat(
|
|
||||||
# terrain.y.size, axis=0
|
|
||||||
# ),
|
|
||||||
x=terrain.x.to_numpy(),
|
x=terrain.x.to_numpy(),
|
||||||
y=terrain.y.to_numpy(),
|
y=terrain.y.to_numpy(),
|
||||||
depth=15, # large_kernels.size_px
|
depth=15, # large_kernels.size_px
|
||||||
|
|
@ -273,14 +272,13 @@ def enrich():
|
||||||
print(client)
|
print(client)
|
||||||
print(client.dashboard_link)
|
print(client.dashboard_link)
|
||||||
|
|
||||||
|
arcticdem_store = get_arcticdem_stores()
|
||||||
accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
|
accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
|
||||||
|
|
||||||
# Garbage collect from previous runs
|
# Garbage collect from previous runs
|
||||||
accessor.repo.garbage_collect(datetime.datetime.now(datetime.UTC))
|
accessor.repo.garbage_collect(datetime.datetime.now(datetime.UTC))
|
||||||
|
|
||||||
adem = accessor.open_xarray()
|
adem = accessor.open_xarray()
|
||||||
# session = adem.repo.readonly_session("main")
|
|
||||||
# adem = xr.open_zarr(session.store, mask_and_scale=False, consolidated=False).set_coords("spatial_ref")
|
|
||||||
del adem.y.attrs["_FillValue"]
|
del adem.y.attrs["_FillValue"]
|
||||||
del adem.x.attrs["_FillValue"]
|
del adem.x.attrs["_FillValue"]
|
||||||
enriched, new_features = _enrich(adem.dem)
|
enriched, new_features = _enrich(adem.dem)
|
||||||
|
|
@ -289,10 +287,9 @@ def enrich():
|
||||||
adem[feature] = enriched.sel(feature=feature)
|
adem[feature] = enriched.sel(feature=feature)
|
||||||
print(adem[new_features])
|
print(adem[new_features])
|
||||||
# subset = adem[new_features].isel(x=slice(190800, 220000), y=slice(61200, 100000))
|
# subset = adem[new_features].isel(x=slice(190800, 220000), y=slice(61200, 100000))
|
||||||
encodings = {feature: {"compressors": [BloscCodec(clevel=5)]} for feature in new_features}
|
|
||||||
# subset.to_zarr("test2.zarr", mode="w", encoding=encodings)
|
|
||||||
|
|
||||||
session = accessor.repo.writable_session("main")
|
session = accessor.repo.writable_session("main")
|
||||||
|
encodings = {feature: {"compressors": [BloscCodec(clevel=5)]} for feature in new_features}
|
||||||
icechunk.xarray.to_icechunk(
|
icechunk.xarray.to_icechunk(
|
||||||
adem[new_features],
|
adem[new_features],
|
||||||
session,
|
session,
|
||||||
|
|
@ -304,5 +301,32 @@ def enrich():
|
||||||
print("Enrichment complete.")
|
print("Enrichment complete.")
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
def aggregate(grid: Literal["hex", "healpix"], level: int):
|
||||||
|
with stopwatch(f"Loading {grid} grid at level {level}"):
|
||||||
|
grid_gdf = grids.open(grid, level)
|
||||||
|
# ? Mask out water, since we don't want to aggregate over oceans
|
||||||
|
grid_gdf = watermask.clip_grid(grid_gdf)
|
||||||
|
|
||||||
|
arcticdem_store = get_arcticdem_stores()
|
||||||
|
accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
|
||||||
|
adem = accessor.open_xarray()
|
||||||
|
assert {"x", "y"} == set(adem.dims)
|
||||||
|
assert adem.odc.crs == "EPSG:3413"
|
||||||
|
|
||||||
|
aggregations = _Aggregations(
|
||||||
|
mean=True,
|
||||||
|
sum=False,
|
||||||
|
std=True,
|
||||||
|
min=True,
|
||||||
|
max=True,
|
||||||
|
quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99],
|
||||||
|
)
|
||||||
|
aggregated = aggregate_raster_into_grid(adem, grid_gdf, aggregations, grid, level)
|
||||||
|
store = get_arcticdem_stores(grid, level)
|
||||||
|
with stopwatch(f"Saving spatially aggregated ArcticDEM data to {store}"):
|
||||||
|
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
# ruff: noqa: PD003, PD011
|
# ruff: noqa: PD003
|
||||||
"""Download and preprocess ERA5 data.
|
"""Download and preprocess ERA5 data.
|
||||||
|
|
||||||
Variables of Interest:
|
Variables of Interest:
|
||||||
|
|
@ -79,7 +79,6 @@ Date: June to October 2025
|
||||||
|
|
||||||
import cProfile
|
import cProfile
|
||||||
import time
|
import time
|
||||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import cyclopts
|
import cyclopts
|
||||||
|
|
@ -88,18 +87,15 @@ import numpy as np
|
||||||
import odc.geo
|
import odc.geo
|
||||||
import odc.geo.xr
|
import odc.geo.xr
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import shapely
|
|
||||||
import shapely.ops
|
|
||||||
import ultraplot as uplt
|
import ultraplot as uplt
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
import xdggs
|
import xdggs
|
||||||
import xvec
|
import xvec
|
||||||
from rich import pretty, print, traceback
|
from rich import pretty, print, traceback
|
||||||
from rich.progress import track
|
|
||||||
from shapely.geometry import LineString, Polygon
|
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
from entropice import codecs, grids, watermask
|
from entropice import codecs, grids, watermask
|
||||||
|
from entropice.aggregators import _Aggregations, aggregate_raster_into_grid
|
||||||
from entropice.paths import FIGURES_DIR, get_era5_stores
|
from entropice.paths import FIGURES_DIR, get_era5_stores
|
||||||
from entropice.xvec import to_xvec
|
from entropice.xvec import to_xvec
|
||||||
|
|
||||||
|
|
@ -647,145 +643,10 @@ def viz(
|
||||||
# ===========================
|
# ===========================
|
||||||
|
|
||||||
|
|
||||||
def _crosses_antimeridian(geom: Polygon) -> bool:
|
|
||||||
coords = shapely.get_coordinates(geom)
|
|
||||||
crosses_any_meridian = (coords[:, 0] > 0).any() and (coords[:, 0] < 0).any()
|
|
||||||
return crosses_any_meridian and abs(coords[:, 0]).max() > 90
|
|
||||||
|
|
||||||
|
|
||||||
def _split_antimeridian_cell(geom: Polygon) -> list[Polygon]:
|
|
||||||
# Assumes that it is a antimeridian hex
|
|
||||||
coords = shapely.get_coordinates(geom)
|
|
||||||
for i in range(coords.shape[0]):
|
|
||||||
if coords[i, 0] < 0:
|
|
||||||
coords[i, 0] += 360
|
|
||||||
geom = Polygon(coords)
|
|
||||||
antimeridian = LineString([[180, -90], [180, 90]])
|
|
||||||
polys = shapely.ops.split(geom, antimeridian)
|
|
||||||
return list(polys.geoms)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
|
|
||||||
enclosing = geobox.enclosing(geom)
|
|
||||||
x, y = enclosing.shape
|
|
||||||
if x <= 1 or y <= 1:
|
|
||||||
return False
|
|
||||||
roi: tuple[slice, slice] = geobox.overlap_roi(enclosing)
|
|
||||||
roix, roiy = roi
|
|
||||||
return (roix.stop - roix.start) > 1 and (roiy.stop - roiy.start) > 1
|
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Getting corrected geometries", log=False)
|
|
||||||
def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox]) -> list[odc.geo.Geometry]:
|
|
||||||
"""Get corrected geometries for antimeridian-crossing polygons.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
geom (Polygon): Input polygon geometry.
|
|
||||||
gbox (odc.geo.geobox.GeoBox): GeoBox for spatial reference.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[odc.geo.Geometry]: List of corrected, georeferenced geometries.
|
|
||||||
|
|
||||||
"""
|
|
||||||
geom, gbox = inp
|
|
||||||
# cell.geometry is a shapely Polygon
|
|
||||||
if not _crosses_antimeridian(geom):
|
|
||||||
geoms = [geom]
|
|
||||||
# Split geometry in case it crossed antimeridian
|
|
||||||
else:
|
|
||||||
geoms = _split_antimeridian_cell(geom)
|
|
||||||
|
|
||||||
geoms = [odc.geo.Geometry(g, crs="epsg:4326") for g in geoms]
|
|
||||||
geoms = list(filter(lambda g: _check_geom(gbox, g), geoms))
|
|
||||||
return geoms
|
|
||||||
|
|
||||||
|
|
||||||
def _correct_longs(ds: xr.Dataset) -> xr.Dataset:
|
def _correct_longs(ds: xr.Dataset) -> xr.Dataset:
|
||||||
return ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180)).sortby("longitude")
|
return ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180)).sortby("longitude")
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Extracting cell data", log=False)
|
|
||||||
def _extract_cell_data(ds, geom):
|
|
||||||
cropped = ds.odc.crop(geom).drop_vars("spatial_ref")
|
|
||||||
with np.errstate(divide="ignore", invalid="ignore"):
|
|
||||||
cell_data = cropped.mean(["latitude", "longitude"])
|
|
||||||
return {var: cell_data[var].values for var in cell_data.data_vars}
|
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Extracting split cell data", log=False)
|
|
||||||
def _extract_split_cell_data(dss: xr.Dataset, geoms):
|
|
||||||
parts: list[xr.Dataset] = [ds.odc.crop(geom).drop_vars("spatial_ref") for ds, geom in zip(dss, geoms)]
|
|
||||||
partial_counts = [part.notnull().sum(dim=["latitude", "longitude"]) for part in parts]
|
|
||||||
with np.errstate(divide="ignore", invalid="ignore"):
|
|
||||||
partial_means = [part.sum(["latitude", "longitude"]) for part in parts]
|
|
||||||
n = xr.concat(partial_counts, dim="part").sum("part")
|
|
||||||
cell_data = xr.concat(partial_means, dim="part").sum("part") / n
|
|
||||||
return {var: cell_data[var].values for var in cell_data.data_vars}
|
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Aligning data")
|
|
||||||
def _align_data(cell_geometries: list[list[odc.geo.Geometry]], unaligned: xr.Dataset) -> dict[str, np.ndarray]:
|
|
||||||
# Persist the dataset, as all the operations here MUST NOT be lazy
|
|
||||||
unaligned = unaligned.load()
|
|
||||||
|
|
||||||
data = {
|
|
||||||
var: np.full((len(cell_geometries), len(unaligned.time)), np.nan, dtype=np.float32)
|
|
||||||
for var in unaligned.data_vars
|
|
||||||
}
|
|
||||||
with ProcessPoolExecutor(max_workers=10) as executor:
|
|
||||||
futures = {}
|
|
||||||
for i, geoms in track(
|
|
||||||
enumerate(cell_geometries), total=len(cell_geometries), description="Submitting cell extraction tasks..."
|
|
||||||
):
|
|
||||||
if len(geoms) == 0:
|
|
||||||
continue
|
|
||||||
elif len(geoms) == 1:
|
|
||||||
geom = geoms[0]
|
|
||||||
# Reduce the amount of data needed to be sent to the worker
|
|
||||||
# Since we dont mask the data, only isel operations are done here
|
|
||||||
# Thus, to properly extract the subset, another masked crop needs to be done in the worker
|
|
||||||
unaligned_subset = unaligned.odc.crop(geom, apply_mask=False)
|
|
||||||
futures[executor.submit(_extract_cell_data, unaligned_subset, geom)] = i
|
|
||||||
else:
|
|
||||||
# Same as above but for multiple parts
|
|
||||||
unaligned_subsets = [unaligned.odc.crop(geom, apply_mask=False) for geom in geoms]
|
|
||||||
futures[executor.submit(_extract_split_cell_data, unaligned_subsets, geoms)] = i
|
|
||||||
|
|
||||||
for future in track(
|
|
||||||
as_completed(futures), total=len(futures), description="Spatially aggregating ERA5 data..."
|
|
||||||
):
|
|
||||||
i = futures[future]
|
|
||||||
cell_data = future.result()
|
|
||||||
for var in unaligned.data_vars:
|
|
||||||
data[var][i, :] = cell_data[var]
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Creating aligned dataset", log=False)
|
|
||||||
def _create_aligned(
|
|
||||||
ds: xr.Dataset, data: dict[str, np.ndarray], grid: Literal["hex", "healpix"], level: int
|
|
||||||
) -> xr.Dataset:
|
|
||||||
cell_ids = grids.get_cell_ids(grid, level)
|
|
||||||
data_vars = {var: (["cell_ids", "time"], values) for var, values in data.items()}
|
|
||||||
aligned = xr.Dataset(
|
|
||||||
data_vars,
|
|
||||||
coords={"cell_ids": cell_ids, "time": ds.time},
|
|
||||||
)
|
|
||||||
gridinfo = {
|
|
||||||
"grid_name": "h3" if grid == "hex" else grid,
|
|
||||||
"level": level,
|
|
||||||
}
|
|
||||||
if grid == "healpix":
|
|
||||||
gridinfo["indexing_scheme"] = "nested"
|
|
||||||
aligned.cell_ids.attrs = gridinfo
|
|
||||||
for var in ds.data_vars:
|
|
||||||
aligned[var].attrs = ds[var].attrs
|
|
||||||
|
|
||||||
aligned = aligned.chunk({"cell_ids": min(len(aligned.cell_ids), 10000), "time": len(aligned.time)})
|
|
||||||
aligned = xdggs.decode(aligned)
|
|
||||||
return aligned
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command
|
@cli.command
|
||||||
def spatial_agg(
|
def spatial_agg(
|
||||||
grid: Literal["hex", "healpix"],
|
grid: Literal["hex", "healpix"],
|
||||||
|
|
@ -816,18 +677,28 @@ def spatial_agg(
|
||||||
assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}"
|
assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}"
|
||||||
unaligned = _correct_longs(unaligned)
|
unaligned = _correct_longs(unaligned)
|
||||||
|
|
||||||
with stopwatch("Precomputing cell geometries"):
|
# with stopwatch("Precomputing cell geometries"):
|
||||||
with ProcessPoolExecutor(max_workers=20) as executor:
|
# with ProcessPoolExecutor(max_workers=20) as executor:
|
||||||
cell_geometries = list(
|
# cell_geometries = list(
|
||||||
executor.map(
|
# executor.map(
|
||||||
_get_corrected_geoms,
|
# _get_corrected_geoms,
|
||||||
[(row.geometry, unaligned.odc.geobox) for _, row in grid_gdf.iterrows()],
|
# [(row.geometry, unaligned.odc.geobox) for _, row in grid_gdf.iterrows()],
|
||||||
)
|
# )
|
||||||
)
|
# )
|
||||||
|
# data = _align_data(cell_geometries, unaligned)
|
||||||
|
# aggregated = _create_aligned(unaligned, data, grid, level)
|
||||||
|
|
||||||
data = _align_data(cell_geometries, unaligned)
|
aggregations = _Aggregations(
|
||||||
|
mean=True,
|
||||||
|
sum=False,
|
||||||
|
std=True,
|
||||||
|
min=True,
|
||||||
|
max=True,
|
||||||
|
quantiles=[0.01, 0.05, 0.25, 0.75, 0.95, 0.99],
|
||||||
|
)
|
||||||
|
aggregated = aggregate_raster_into_grid(unaligned, grid_gdf, aggregations, grid, level)
|
||||||
|
|
||||||
aggregated = _create_aligned(unaligned, data, grid, level)
|
aggregated = aggregated.chunk({"cell_ids": min(len(aggregated.cell_ids), 10000), "time": len(aggregated.time)})
|
||||||
store = get_era5_stores(agg, grid, level)
|
store = get_era5_stores(agg, grid, level)
|
||||||
with stopwatch(f"Saving spatially aggregated {agg} ERA5 data to {store}"):
|
with stopwatch(f"Saving spatially aggregated {agg} ERA5 data to {store}"):
|
||||||
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
|
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ GRIDS_DIR = DATA_DIR / "grids"
|
||||||
FIGURES_DIR = Path("figures")
|
FIGURES_DIR = Path("figures")
|
||||||
DARTS_DIR = DATA_DIR / "darts"
|
DARTS_DIR = DATA_DIR / "darts"
|
||||||
ERA5_DIR = DATA_DIR / "era5"
|
ERA5_DIR = DATA_DIR / "era5"
|
||||||
|
ARCTICDEM_DIR = DATA_DIR / "arcticdem"
|
||||||
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
|
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
|
||||||
WATERMASK_DIR = DATA_DIR / "watermask"
|
WATERMASK_DIR = DATA_DIR / "watermask"
|
||||||
TRAINING_DIR = DATA_DIR / "training"
|
TRAINING_DIR = DATA_DIR / "training"
|
||||||
|
|
@ -22,12 +23,12 @@ GRIDS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
|
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
DARTS_DIR.mkdir(parents=True, exist_ok=True)
|
DARTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
ERA5_DIR.mkdir(parents=True, exist_ok=True)
|
ERA5_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
ARCTICDEM_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
|
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
WATERMASK_DIR.mkdir(parents=True, exist_ok=True)
|
WATERMASK_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
TRAINING_DIR.mkdir(parents=True, exist_ok=True)
|
TRAINING_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
arcticdem_store = DATA_DIR / "arcticdem32m.icechunk.zarr"
|
|
||||||
|
|
||||||
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
|
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
|
||||||
|
|
||||||
|
|
@ -83,6 +84,17 @@ def get_era5_stores(
|
||||||
return aligned_path
|
return aligned_path
|
||||||
|
|
||||||
|
|
||||||
|
def get_arcticdem_stores(
|
||||||
|
grid: Literal["hex", "healpix"] | None = None,
|
||||||
|
level: int | None = None,
|
||||||
|
):
|
||||||
|
if grid is None or level is None:
|
||||||
|
return DATA_DIR / "arcticdem32m.icechunk.zarr"
|
||||||
|
gridname = _get_gridname(grid, level)
|
||||||
|
aligned_path = ARCTICDEM_DIR / f"{gridname}_dem.zarr"
|
||||||
|
return aligned_path
|
||||||
|
|
||||||
|
|
||||||
def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
dataset_file = TRAINING_DIR / f"{gridname}_train_dataset.parquet"
|
dataset_file = TRAINING_DIR / f"{gridname}_train_dataset.parquet"
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
# ruff: noqa: N806
|
# ruff: noqa: N806
|
||||||
"""Training dataset preparation and model training."""
|
"""Training of classification models training."""
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
@ -10,12 +10,15 @@ import pandas as pd
|
||||||
import toml
|
import toml
|
||||||
import torch
|
import torch
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
|
from cuml.ensemble import RandomForestClassifier
|
||||||
|
from cuml.neighbors import KNeighborsClassifier
|
||||||
from entropy import ESPAClassifier
|
from entropy import ESPAClassifier
|
||||||
from rich import pretty, traceback
|
from rich import pretty, traceback
|
||||||
from scipy.stats import loguniform, randint
|
from scipy.stats import loguniform, randint
|
||||||
from sklearn import set_config
|
from sklearn import set_config
|
||||||
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
|
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
from xgboost import XGBClassifier
|
||||||
|
|
||||||
from entropice.inference import predict_proba
|
from entropice.inference import predict_proba
|
||||||
from entropice.paths import (
|
from entropice.paths import (
|
||||||
|
|
@ -29,13 +32,13 @@ pretty.install()
|
||||||
set_config(array_api_dispatch=True)
|
set_config(array_api_dispatch=True)
|
||||||
|
|
||||||
|
|
||||||
def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "multi"] = "binary"):
|
def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["binary", "count", "density"] = "binary"):
|
||||||
"""Create X and y data from the training dataset.
|
"""Create X and y data from the training dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||||
level (int): The grid level to use.
|
level (int): The grid level to use.
|
||||||
task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary".
|
task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[pd.DataFrame, pd.DataFrame, pd.Series, list]: The data, Features (X), labels (y), and label names.
|
Tuple[pd.DataFrame, pd.DataFrame, pd.Series, list]: The data, Features (X), labels (y), and label names.
|
||||||
|
|
@ -51,7 +54,7 @@ def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["b
|
||||||
if task == "binary":
|
if task == "binary":
|
||||||
labels = ["No RTS", "RTS"]
|
labels = ["No RTS", "RTS"]
|
||||||
y_data = data.loc[X_data.index, "darts_has_rts"]
|
y_data = data.loc[X_data.index, "darts_has_rts"]
|
||||||
else:
|
elif task == "count":
|
||||||
# Put into n categories (log scaled)
|
# Put into n categories (log scaled)
|
||||||
y_data = data.loc[X_data.index, "darts_rts_count"]
|
y_data = data.loc[X_data.index, "darts_rts_count"]
|
||||||
n_categories = 5
|
n_categories = 5
|
||||||
|
|
@ -63,6 +66,17 @@ def create_xy_data(grid: Literal["hex", "healpix"], level: int, task: Literal["b
|
||||||
y_data = pd.cut(y_data, bins=bins)
|
y_data = pd.cut(y_data, bins=bins)
|
||||||
labels = [str(v) for v in y_data.sort_values().unique()]
|
labels = [str(v) for v in y_data.sort_values().unique()]
|
||||||
y_data = y_data.cat.codes
|
y_data = y_data.cat.codes
|
||||||
|
elif task == "density":
|
||||||
|
y_data = data.loc[X_data.index, "darts_rts_density"]
|
||||||
|
n_categories = 5
|
||||||
|
bins = pd.qcut(y_data, q=n_categories, duplicates="drop").unique().categories
|
||||||
|
# Change the first interval to start at 0
|
||||||
|
bins = pd.IntervalIndex.from_tuples([(0.0, interval.right) for interval in bins])
|
||||||
|
y_data = pd.cut(y_data, bins=bins)
|
||||||
|
labels = [str(v) for v in y_data.sort_values().unique()]
|
||||||
|
y_data = y_data.cat.codes
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown task: {task}")
|
||||||
return data, X_data, y_data, labels
|
return data, X_data, y_data, labels
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -71,7 +85,8 @@ def random_cv(
|
||||||
level: int,
|
level: int,
|
||||||
n_iter: int = 2000,
|
n_iter: int = 2000,
|
||||||
robust: bool = False,
|
robust: bool = False,
|
||||||
task: Literal["binary", "multi"] = "binary",
|
task: Literal["binary", "count", "density"] = "binary",
|
||||||
|
model: Literal["espa", "xgboost", "rf", "knn"] = "espa",
|
||||||
):
|
):
|
||||||
"""Perform random cross-validation on the training dataset.
|
"""Perform random cross-validation on the training dataset.
|
||||||
|
|
||||||
|
|
@ -80,38 +95,74 @@ def random_cv(
|
||||||
level (int): The grid level to use.
|
level (int): The grid level to use.
|
||||||
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
||||||
robust (bool, optional): Whether to use robust training. Defaults to False.
|
robust (bool, optional): Whether to use robust training. Defaults to False.
|
||||||
task (Literal["binary", "multi"], optional): The classification task type. Defaults to "binary".
|
task (Literal["binary", "count", "density"], optional): The classification task type. Defaults to "binary".
|
||||||
|
|
||||||
"""
|
"""
|
||||||
_, X_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task)
|
_, X_data, y_data, labels = create_xy_data(grid=grid, level=level, task=task)
|
||||||
print(f"Using {task}-class classification with {len(labels)} classes: {labels}")
|
print(f"Using {task}-class classification with {len(labels)} classes: {labels}")
|
||||||
print(f"{y_data.describe()=}")
|
print(f"{y_data.describe()=}")
|
||||||
X = X_data.to_numpy(dtype="float32")
|
X = X_data.to_numpy(dtype="float64")
|
||||||
y = y_data.to_numpy(dtype="int8")
|
y = y_data.to_numpy(dtype="int8")
|
||||||
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
|
X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
|
||||||
print(f"{X.shape=}, {y.shape=}")
|
print(f"{X.shape=}, {y.shape=}")
|
||||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||||
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
|
print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")
|
||||||
|
|
||||||
|
if model == "espa":
|
||||||
|
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
|
||||||
|
if task == "binary":
|
||||||
param_grid = {
|
param_grid = {
|
||||||
"eps_cl": loguniform(1e-3, 1e7),
|
"eps_cl": loguniform(1e-4, 1e1),
|
||||||
"eps_e": loguniform(1e-3, 1e7),
|
"eps_e": loguniform(1e1, 1e7),
|
||||||
"initial_K": randint(20, 400),
|
"initial_K": randint(20, 400),
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
clf = ESPAClassifier(20, 0.1, 0.1, random_state=42, robust=robust)
|
param_grid = {
|
||||||
|
"eps_cl": loguniform(1e-11, 1e-6),
|
||||||
|
"eps_e": loguniform(1e4, 1e8),
|
||||||
|
"initial_K": randint(400, 800),
|
||||||
|
}
|
||||||
|
elif model == "xgboost":
|
||||||
|
param_grid = {
|
||||||
|
"learning_rate": loguniform(1e-4, 1e-1),
|
||||||
|
"max_depth": randint(3, 15),
|
||||||
|
"n_estimators": randint(100, 1000),
|
||||||
|
"subsample": loguniform(0.5, 1.0),
|
||||||
|
"colsample_bytree": loguniform(0.5, 1.0),
|
||||||
|
}
|
||||||
|
clf = XGBClassifier(
|
||||||
|
objective="multi:softprob" if task != "binary" else "binary:logistic",
|
||||||
|
eval_metric="mlogloss" if task != "binary" else "logloss",
|
||||||
|
random_state=42,
|
||||||
|
tree_method="gpu_hist",
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
elif model == "rf":
|
||||||
|
param_grid = {
|
||||||
|
"max_depth": randint(5, 50),
|
||||||
|
"n_estimators": randint(50, 500),
|
||||||
|
}
|
||||||
|
clf = RandomForestClassifier(random_state=42)
|
||||||
|
elif model == "knn":
|
||||||
|
param_grid = {
|
||||||
|
"n_neighbors": randint(3, 15),
|
||||||
|
"weights": ["uniform", "distance"],
|
||||||
|
"algorithm": ["brute", "kd_tree", "ball_tree"],
|
||||||
|
}
|
||||||
|
clf = KNeighborsClassifier(random_state=42)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model: {model}")
|
||||||
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
||||||
if task == "binary":
|
if task == "binary":
|
||||||
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
|
metrics = ["accuracy", "recall", "precision", "f1", "jaccard"] # "roc_auc" does not work on GPU
|
||||||
else:
|
else:
|
||||||
metrics = [
|
metrics = [
|
||||||
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro",
|
"accuracy", # equals "f1_micro", "precision_micro", "recall_micro", "recall_weighted"
|
||||||
"f1_macro",
|
"f1_macro",
|
||||||
"f1_weighted",
|
"f1_weighted",
|
||||||
"precision_macro",
|
"precision_macro",
|
||||||
"precision_weighted",
|
"precision_weighted",
|
||||||
"recall_macro",
|
"recall_macro",
|
||||||
"recall_weighted",
|
|
||||||
"jaccard_micro",
|
"jaccard_micro",
|
||||||
"jaccard_macro",
|
"jaccard_macro",
|
||||||
"jaccard_weighted",
|
"jaccard_weighted",
|
||||||
|
|
@ -120,17 +171,17 @@ def random_cv(
|
||||||
clf,
|
clf,
|
||||||
param_grid,
|
param_grid,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
n_jobs=16,
|
n_jobs=8,
|
||||||
cv=cv,
|
cv=cv,
|
||||||
random_state=42,
|
random_state=42,
|
||||||
verbose=3,
|
verbose=10,
|
||||||
scoring=metrics,
|
scoring=metrics,
|
||||||
refit="f1" if task == "binary" else "f1_weighted",
|
refit="f1" if task == "binary" else "f1_weighted",
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
|
print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
|
||||||
with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
|
with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
|
||||||
search.fit(X_train, y_train, max_iter=100)
|
search.fit(X_train, y_train, max_iter=300)
|
||||||
|
|
||||||
print("Best parameters combination found:")
|
print("Best parameters combination found:")
|
||||||
best_parameters = search.best_estimator_.get_params()
|
best_parameters = search.best_estimator_.get_params()
|
||||||
|
|
@ -144,29 +195,33 @@ def random_cv(
|
||||||
results_dir = get_cv_results_dir("random_search", grid=grid, level=level, task=task)
|
results_dir = get_cv_results_dir("random_search", grid=grid, level=level, task=task)
|
||||||
|
|
||||||
# Store the search settings
|
# Store the search settings
|
||||||
|
# First convert the param_grid distributions to a serializable format
|
||||||
|
param_grid_serializable = {}
|
||||||
|
for key, dist in param_grid.items():
|
||||||
|
if isinstance(dist, loguniform):
|
||||||
|
param_grid_serializable[key] = {
|
||||||
|
"distribution": "loguniform",
|
||||||
|
"low": dist.a,
|
||||||
|
"high": dist.b,
|
||||||
|
}
|
||||||
|
elif isinstance(dist, randint):
|
||||||
|
param_grid_serializable[key] = {
|
||||||
|
"distribution": "randint",
|
||||||
|
"low": dist.a,
|
||||||
|
"high": dist.b,
|
||||||
|
}
|
||||||
|
elif isinstance(dist, list):
|
||||||
|
param_grid_serializable[key] = dist
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown distribution type for {key}: {type(dist)}")
|
||||||
settings = {
|
settings = {
|
||||||
"task": task,
|
"task": task,
|
||||||
|
"model": model,
|
||||||
"grid": grid,
|
"grid": grid,
|
||||||
"level": level,
|
"level": level,
|
||||||
"random_state": 42,
|
"random_state": 42,
|
||||||
"n_iter": n_iter,
|
"n_iter": n_iter,
|
||||||
"param_grid": {
|
"param_grid": param_grid_serializable,
|
||||||
"eps_cl": {
|
|
||||||
"distribution": "loguniform",
|
|
||||||
"low": param_grid["eps_cl"].a,
|
|
||||||
"high": param_grid["eps_cl"].b,
|
|
||||||
},
|
|
||||||
"eps_e": {
|
|
||||||
"distribution": "loguniform",
|
|
||||||
"low": param_grid["eps_e"].a,
|
|
||||||
"high": param_grid["eps_e"].b,
|
|
||||||
},
|
|
||||||
"initial_K": {
|
|
||||||
"distribution": "randint",
|
|
||||||
"low": param_grid["initial_K"].a,
|
|
||||||
"high": param_grid["initial_K"].b,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"cv_splits": cv.get_n_splits(),
|
"cv_splits": cv.get_n_splits(),
|
||||||
"metrics": metrics,
|
"metrics": metrics,
|
||||||
"classes": labels,
|
"classes": labels,
|
||||||
|
|
@ -195,6 +250,7 @@ def random_cv(
|
||||||
results.to_parquet(results_file)
|
results.to_parquet(results_file)
|
||||||
|
|
||||||
# Get the inner state of the best estimator
|
# Get the inner state of the best estimator
|
||||||
|
if model == "espa":
|
||||||
best_estimator = search.best_estimator_
|
best_estimator = search.best_estimator_
|
||||||
# Annotate the state with xarray metadata
|
# Annotate the state with xarray metadata
|
||||||
features = X_data.columns.tolist()
|
features = X_data.columns.tolist()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue