Add arcticdem

This commit is contained in:
Tobias Hölzer 2025-11-22 18:56:34 +01:00
parent d5b35d6da4
commit 1a71883999
18 changed files with 6981 additions and 2068 deletions

View file

@ -1 +1 @@
3.12 3.13

8289
pixi.lock generated

File diff suppressed because it is too large Load diff

View file

@ -4,7 +4,7 @@ version = "0.1.0"
description = "Add your description here" description = "Add your description here"
readme = "README.md" readme = "README.md"
authors = [{ name = "Tobias Hölzer", email = "tobiashoelzer@hotmail.com" }] authors = [{ name = "Tobias Hölzer", email = "tobiashoelzer@hotmail.com" }]
# requires-python = ">=3.10,<3.13" requires-python = ">=3.13,<3.14"
dependencies = [ dependencies = [
"aiohttp>=3.12.11", "aiohttp>=3.12.11",
"bokeh>=3.7.3", "bokeh>=3.7.3",
@ -28,19 +28,19 @@ dependencies = [
"mapclassify>=2.9.0", "mapclassify>=2.9.0",
"matplotlib>=3.10.3", "matplotlib>=3.10.3",
"netcdf4>=1.7.2", "netcdf4>=1.7.2",
"numba>=0.62.1", "numba>=0.61.2",
"numbagg>=0.9.3", "numbagg>=0.9.3",
"numpy>=2.3.0", "numpy>=2.2.6",
"odc-geo[all]>=0.4.10", "odc-geo[all]>=0.4.10",
"opt-einsum>=3.4.0", "opt-einsum>=3.4.0",
"pyarrow>=20.0.0", "pyarrow>=18.1.0",
"rechunker>=0.5.2", "rechunker>=0.5.2",
"requests>=2.32.3", "requests>=2.32.3",
"rich>=14.0.0", "rich>=14.0.0",
"rioxarray>=0.19.0", "rioxarray>=0.19.0",
"scipy>=1.15.3", "scipy>=1.15.3",
"seaborn>=0.13.2", "seaborn>=0.13.2",
"smart-geocubes[gee,dask,stac,viz]>=0.0.9", "smart-geocubes[stac]>=0.1.0",
"stopuhr>=0.0.10", "stopuhr>=0.0.10",
"ultraplot>=1.63.0", "ultraplot>=1.63.0",
"xanimate", "xanimate",
@ -51,7 +51,13 @@ dependencies = [
"geocube>=0.7.1,<0.8", "geocube>=0.7.1,<0.8",
"streamlit>=1.50.0,<2", "streamlit>=1.50.0,<2",
"altair[all]>=5.5.0,<6", "altair[all]>=5.5.0,<6",
"h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26", "st-theme>=1.2.3,<2", "h5netcdf>=1.7.3,<2",
"streamlit-folium>=0.25.3,<0.26",
"st-theme>=1.2.3,<2",
"xgboost>=3.1.1,<4",
"s3fs>=2025.10.0,<2026",
"xarray-spatial",
"cupy-xarray>=0.1.4,<0.2",
] ]
[project.scripts] [project.scripts]
@ -59,6 +65,7 @@ create-grid = "entropice.grids:main"
darts = "entropice.darts:main" darts = "entropice.darts:main"
alpha-earth = "entropice.alphaearth:main" alpha-earth = "entropice.alphaearth:main"
era5 = "entropice.era5:cli" era5 = "entropice.era5:cli"
arcticdem = "entropice.arcticdem:cli"
train = "entropice.training:main" train = "entropice.training:main"
dataset = "entropice.dataset:main" dataset = "entropice.dataset:main"
@ -69,10 +76,20 @@ build-backend = "hatchling.build"
[tool.uv] [tool.uv]
package = true package = true
[[tool.uv.index]]
name = "nvidia"
url = "https://pypi.nvidia.com"
explicit = true
[tool.uv.sources] [tool.uv.sources]
entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" } entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
entropy = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" } entropy = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" }
xanimate = { git = "https://github.com/davbyr/xAnimate" } xanimate = { git = "https://github.com/davbyr/xAnimate" }
xdem = { git = "https://github.com/GlacioHack/xdem" }
xarray-spatial = { git = "https://github.com/relativityhd/xarray-spatial" }
cudf-cu12 = { index = "nvidia" }
cuml-cu12 = { index = "nvidia" }
cuspatial-cu12 = { index = "nvidia" }
[tool.ruff.lint.pyflakes] [tool.ruff.lint.pyflakes]
# Ignore libraries when checking for unused imports # Ignore libraries when checking for unused imports
@ -87,7 +104,8 @@ allowed-unused-imports = [
] ]
[tool.pixi.workspace] [tool.pixi.workspace]
channels = ["conda-forge"] channels = ["nvidia", "conda-forge", "rapidsai"]
channel-priority = "disabled"
platforms = ["linux-64"] platforms = ["linux-64"]
[tool.pixi.activation.env] [tool.pixi.activation.env]
@ -116,4 +134,5 @@ cupy = ">=13.6.0,<14"
nccl = ">=2.27.7.1,<3" nccl = ">=2.27.7.1,<3"
cudnn = ">=9.13.1.26,<10" cudnn = ">=9.13.1.26,<10"
cusparselt = ">=0.8.1.1,<0.9" cusparselt = ">=0.8.1.1,<0.9"
cuda-version = "12.1.*" cuda-version = "12.9.*"
rapids = ">=25.10.0,<26"

11
scripts/00grids.sh Normal file
View file

@ -0,0 +1,11 @@
#! /bin/bash
# Build every grid resolution used downstream: hex levels 3-6, healpix levels 6-10.
for lvl in 3 4 5 6; do
    pixi run create-grid --grid hex --level "$lvl"
done
for lvl in 6 7 8 9 10; do
    pixi run create-grid --grid healpix --level "$lvl"
done

11
scripts/01darts.sh Normal file
View file

@ -0,0 +1,11 @@
#! /bin/bash
# Run the darts extraction for each grid resolution: hex levels 3-6, healpix levels 6-10.
for lvl in 3 4 5 6; do
    pixi run darts --grid hex --level "$lvl"
done
for lvl in 6 7 8 9 10; do
    pixi run darts --grid healpix --level "$lvl"
done

23
scripts/02alphaearth.sh Normal file
View file

@ -0,0 +1,23 @@
#!/bin/bash
# Download AlphaEarth embeddings per grid, then merge each grid into a zarr store.
# The hex-6 / healpix-10 resolutions are handled in a second batch at the end.
for lvl in 3 4 5; do
    pixi run alpha-earth download --grid hex --level "$lvl"
done
for lvl in 6 7 8 9; do
    pixi run alpha-earth download --grid healpix --level "$lvl"
done
for lvl in 3 4 5; do
    pixi run alpha-earth combine-to-zarr --grid hex --level "$lvl"
done
for lvl in 6 7 8 9; do
    pixi run alpha-earth combine-to-zarr --grid healpix --level "$lvl"
done
pixi run alpha-earth download --grid hex --level 6
pixi run alpha-earth download --grid healpix --level 10
pixi run alpha-earth combine-to-zarr --grid hex --level 6
pixi run alpha-earth combine-to-zarr --grid healpix --level 10

15
scripts/03era5.sh Normal file
View file

@ -0,0 +1,15 @@
#!/bin/bash
# pixi run era5 download
# pixi run era5 enrich
# Spatially aggregate ERA5 data onto every grid resolution (hex 3-6, healpix 6-10).
for lvl in 3 4 5 6; do
    pixi run era5 spatial-agg --grid hex --level "$lvl"
done
for lvl in 6 7 8 9 10; do
    pixi run era5 spatial-agg --grid healpix --level "$lvl"
done

View file

@ -1,17 +0,0 @@
#!/bin/bash
# uv run alpha-earth download --grid hex --level 3
# uv run alpha-earth download --grid hex --level 4
# uv run alpha-earth download --grid hex --level 5
# uv run alpha-earth download --grid healpix --level 6
# uv run alpha-earth download --grid healpix --level 7
# uv run alpha-earth download --grid healpix --level 8
# uv run alpha-earth download --grid healpix --level 9
uv run alpha-earth combine-to-zarr --grid hex --level 3
uv run alpha-earth combine-to-zarr --grid hex --level 4
uv run alpha-earth combine-to-zarr --grid hex --level 5
uv run alpha-earth combine-to-zarr --grid healpix --level 6
uv run alpha-earth combine-to-zarr --grid healpix --level 7
uv run alpha-earth combine-to-zarr --grid healpix --level 8
uv run alpha-earth combine-to-zarr --grid healpix --level 9

View file

@ -1,66 +0,0 @@
"""Download ERA5 data from the Copernicus Data Store.
Web platform: https://cds.climate.copernicus.eu
"""
import os
import re
from pathlib import Path
import cdsapi
import cyclopts
from rich import pretty, print, traceback
traceback.install()
pretty.install()
DATA_DIR = Path(os.environ.get("DATA_DIR", "../../data")) / "entropyc-rts"
def hourly(years: str):
"""Download ERA5 data from the Copernicus Data Store.
Args:
years (str): Years to download, seperated by a '-'.
"""
assert re.compile(r"^\d{4}-\d{4}$").match(years), "Years must be in the format 'YYYY-YYYY'"
start_year, end_year = map(int, years.split("-"))
assert 1950 <= start_year <= end_year <= 2024, "Years must be between 1950 and 2024"
dataset = "reanalysis-era5-single-levels"
client = cdsapi.Client(wait_until_complete=False)
outdir = (DATA_DIR / "era5/cds").resolve()
outdir.mkdir(parents=True, exist_ok=True)
print(f"Downloading ERA5 data from {start_year} to {end_year}...")
for y in range(start_year, end_year + 1):
for month in [f"{i:02d}" for i in range(1, 13)]:
request = {
"product_type": ["reanalysis"],
"variable": [
"2m_temperature",
"total_precipitation",
"snow_depth",
"snow_density",
"snowfall",
"lake_ice_temperature",
"surface_sensible_heat_flux",
],
"year": [str(y)],
"month": [month],
"day": [f"{i:02d}" for i in range(1, 32)],
"time": [f"{i:02d}:00" for i in range(0, 24)],
"data_format": "netcdf",
"download_format": "unarchived",
"area": [85, -180, 50, 180],
}
outpath = outdir / f"era5_{y}_{month}.zip"
client.retrieve(dataset, request).download(str(outpath))
print(f"Downloaded {dataset} for {y}-{month}")
if __name__ == "__main__":
cyclopts.run(hourly)

View file

@ -1,18 +0,0 @@
#!/bin/bash
# uv run era5 download
# uv run era5 enrich
# Can be summer, winter or yearly
agg=$1
echo "Running ERA5 spatial aggregation for aggregation type: $agg"
uv run era5 spatial-agg --grid hex --level 3 --agg $agg
uv run era5 spatial-agg --grid hex --level 4 --agg $agg
uv run era5 spatial-agg --grid hex --level 5 --agg $agg
uv run era5 spatial-agg --grid healpix --level 6 --agg $agg
uv run era5 spatial-agg --grid healpix --level 7 --agg $agg
uv run era5 spatial-agg --grid healpix --level 8 --agg $agg
uv run era5 spatial-agg --grid healpix --level 9 --agg $agg

View file

@ -1,9 +0,0 @@
#!/bin/bash
uv run rts --grid hex --level 3
uv run rts --grid hex --level 4
uv run rts --grid hex --level 5
uv run rts --grid healpix --level 6
uv run rts --grid healpix --level 7
uv run rts --grid healpix --level 8
uv run rts --grid healpix --level 9

View file

@ -30,6 +30,8 @@ ee.Initialize(project="ee-tobias-hoelzer")
cli = cyclopts.App(name="alpha-earth") cli = cyclopts.App(name="alpha-earth")
# 7454521782,230147807,10000000.
@cli.command() @cli.command()
def download(grid: Literal["hex", "healpix"], level: int): def download(grid: Literal["hex", "healpix"], level: int):
@ -60,6 +62,8 @@ def download(grid: Literal["hex", "healpix"], level: int):
.combine(ee.Reducer.mean(), sharedInputs=True) .combine(ee.Reducer.mean(), sharedInputs=True)
.combine(ee.Reducer.percentile([1, 5, 25, 75, 95, 99]), sharedInputs=True), .combine(ee.Reducer.percentile([1, 5, 25, 75, 95, 99]), sharedInputs=True),
geometry=geom, geometry=geom,
scale=10,
bestEffort=True,
) )
# Add mean embedding values as properties to the feature # Add mean embedding values as properties to the feature
return feature.set(mean_dict) return feature.set(mean_dict)
@ -85,7 +89,7 @@ def download(grid: Literal["hex", "healpix"], level: int):
# Combine all batch results # Combine all batch results
df = pd.concat(all_results, ignore_index=True) df = pd.concat(all_results, ignore_index=True)
embeddings_on_grid = grid.merge(df[[*bands, "cell_id"]], on="cell_id", how="left") embeddings_on_grid = grid_gdf.merge(df[[*bands, "cell_id"]], on="cell_id", how="left")
embeddings_file = get_annual_embeddings_file(grid, level, year) embeddings_file = get_annual_embeddings_file(grid, level, year)
embeddings_on_grid.to_parquet(embeddings_file) embeddings_on_grid.to_parquet(embeddings_file)
print(f"Saved embeddings for year {year} to {embeddings_file}.") print(f"Saved embeddings for year {year} to {embeddings_file}.")

308
src/entropice/arcticdem.py Normal file
View file

@ -0,0 +1,308 @@
import datetime
from dataclasses import dataclass
from math import ceil
import cupy as cp
import cupy_xarray
import cupyx.scipy.signal
import cyclopts
import dask.array
import dask.distributed as dd
import icechunk.xarray
import numpy as np
import smart_geocubes
import xarray as xr
import xrspatial
import zarr
from cupyx.scipy.ndimage import binary_dilation, binary_erosion, distance_transform_edt
from rich import pretty, print, traceback
from stopuhr import stopwatch
from xrspatial.aspect import _run_cupy as aspect_cupy
from xrspatial.curvature import _run_cupy as curvature_cupy
from xrspatial.slope import _run_cupy as slope_cupy
from zarr.codecs import BloscCodec
from entropice.paths import arcticdem_store
# Rich pretty-printing for exceptions and objects; hide cyclopts frames from tracebacks.
traceback.install(show_locals=True, suppress=[cyclopts])
pretty.install()
# Configure zarr to allow 128-way async concurrency for store I/O.
zarr.config.set({"async.concurrency": 128})
cli = cyclopts.App(name="arcticdem")
@cli.command()
def download():
    """Download the ArcticDEM 32 m mosaic into the local icechunk store.

    Creates the datacube store if it does not already exist, then procedurally
    downloads every tile covering the accessor's full extent.
    """
    adem = smart_geocubes.ArcticDEM32m(arcticdem_store)
    adem.create(exists_ok=True)
    with stopwatch("Download ArcticDEM data"):
        adem.procedural_download(adem.extent.extent, None)
@dataclass
class KernelFactory:
    """Builds convolution kernels (float32 CuPy arrays) for the terrain metrics.

    Attributes:
        res: raster resolution in map units per pixel (32 for the 32 m DEM)
        size_px: kernel edge length in pixels
    """

    res: int
    size_px: int

    @property
    def size(self):
        # Kernel edge length expressed in map units.
        return self.size_px * self.res

    @property
    def inner(self):
        # Inner radius for the annulus kernel: one third of the outer size, rounded up.
        return ceil(self.size / 3)

    @staticmethod
    def _to_cupy_f32(kernel):
        # Move a host-side kernel to the GPU as float32.
        return cp.asarray(kernel.astype("float32"))

    def ring(self):
        # Annulus (ring) kernel used for TPI neighborhoods.
        return self._to_cupy_f32(xrspatial.convolution.annulus_kernel(self.res, self.res, self.size, self.inner))

    def tri(self):
        # Square neighbor kernel for TRI: all ones except the zeroed center cell.
        kernel = np.ones((self.size_px, self.size_px), dtype=float)
        kernel[self.size_px // 2, self.size_px // 2] = 0  # Set the center cell to 0
        return self._to_cupy_f32(xrspatial.convolution.custom_kernel(kernel))

    def vrm(self):
        # Box kernel normalized by pixel count (a mean filter) for vector ruggedness.
        kernel = np.ones((self.size_px, self.size_px), dtype=float) / self.size_px**2
        return self._to_cupy_f32(xrspatial.convolution.custom_kernel(kernel))
def tpi_cupy(chunk, kernels: KernelFactory):
    """Topographic Position Index: elevation minus the mean of the ring neighborhood."""
    kernel = kernels.ring()
    # Normalize the ring so the convolution yields a neighborhood mean.
    kernel = kernel / cp.nansum(kernel)
    tpi = chunk - cupyx.scipy.signal.convolve2d(chunk, kernel, mode="same")
    return tpi
def tri_cupy(chunk, kernels: KernelFactory):
    """Terrain Ruggedness Index via an expanded sum-of-squared-differences.

    Uses the identity sum_i (z - z_i)^2 = (n-1)*z^2 - 2*z*sum_i(z_i) + sum_i(z_i^2)
    over the center-zeroed kernel (n = kernel.size, so n-1 neighbor cells), which
    lets the whole window be evaluated with two convolutions.
    """
    kernel = kernels.tri()
    c2 = chunk**2
    focal_sum = cupyx.scipy.signal.convolve2d(chunk, kernel, mode="same")
    focal_sum_squared = cupyx.scipy.signal.convolve2d(c2, kernel, mode="same")
    # np.sqrt dispatches to CuPy via the array protocol; result stays on the GPU.
    tri = np.sqrt((kernel.size - 1) * c2 - 2 * chunk * focal_sum + focal_sum_squared)
    return tri
def ruggedness_cupy(chunk, slope, aspect, kernels: KernelFactory):
    """Vector Ruggedness Measure: 1 minus the magnitude of the mean surface normal.

    Decomposes each cell's unit normal (from slope/aspect, both in degrees) into
    x/y/z components, averages them over the kernel neighborhood, and returns
    1 - |resultant|, so 0 is smooth terrain and values toward 1 are rugged.

    Args:
        chunk: elevation chunk; unused here, kept for API symmetry with the
            other terrain functions.
        slope: slope raster in degrees.
        aspect: aspect raster in degrees; flat cells are flagged with -1
            (xrspatial convention — confirm against the aspect implementation).
        kernels: factory providing the normalized box kernel.
    """
    slope_rad = slope * (cp.pi / 180)
    # Neutralize the -1 flat-cell sentinel BEFORE converting to radians.
    # (The previous code compared the already-converted radians against -1,
    # which never matched; harmless only because flat cells have slope 0.)
    aspect = cp.where(aspect == -1, 0, aspect)
    aspect_rad = aspect * (cp.pi / 180)
    xy = cp.sin(slope_rad)
    z = cp.cos(slope_rad)
    x = xy * cp.sin(aspect_rad)
    y = xy * cp.cos(aspect_rad)
    kernel = kernels.vrm()
    # Calculate sums of x, y, and z components in the neighborhood
    x_sum = cupyx.scipy.signal.convolve2d(x, kernel, mode="same")
    y_sum = cupyx.scipy.signal.convolve2d(y, kernel, mode="same")
    z_sum = cupyx.scipy.signal.convolve2d(z, kernel, mode="same")
    # Calculate the resultant vector magnitude
    vrm = 1 - np.sqrt(x_sum**2 + y_sum**2 + z_sum**2)
    return vrm
def _get_xy_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) -> tuple[cp.array, cp.array]:
    """Build 2-D GPU coordinate grids (xx, yy) matching this overlap-padded chunk.

    Args:
        chunk: the overlap-padded elevation chunk (used only for its shape).
        x: full-raster x coordinate vector (1-D, host memory).
        y: full-raster y coordinate vector (1-D, host memory).
        block_info: dask map_overlap metadata; entry 0 supplies "chunk-location".

    Returns:
        (xx, yy): 2-D CuPy arrays with one coordinate value per pixel of `chunk`.
    """
    chunk_loc = block_info[0]["chunk-location"]
    # NOTE(review): `d` must equal the depth=15 halo used in _enrich, and `cs`
    # must equal the dask chunk edge length (3600 px) — keep all three in sync.
    d = 15
    cs = 3600
    # Calculate safe slice bounds for edge chunks
    y_start, y_end = max(0, cs * chunk_loc[0] - d), min(len(y), cs * chunk_loc[0] + cs + d)
    x_start, x_end = max(0, cs * chunk_loc[1] - d), min(len(x), cs * chunk_loc[1] + cs + d)
    # Extract coordinate arrays with safe bounds
    y_chunk = cp.asarray(y[y_start:y_end])
    x_chunk = cp.asarray(x[x_start:x_end])
    # Handle cases where the extracted chunk doesn't match expected dimensions
    if len(y_chunk) != chunk.shape[0] or len(x_chunk) != chunk.shape[1]:
        print(
            f"Adjusting coordinate chunk sizes: y_chunk {len(y_chunk)} vs chunk {chunk.shape[0]}, "
            f"x_chunk {len(x_chunk)} vs chunk {chunk.shape[1]}"
        )
        # Pad coordinates to match chunk dimensions if needed
        if len(y_chunk) < chunk.shape[0]:
            # Pad with the edge values
            pad_start = chunk.shape[0] - len(y_chunk)
            y_chunk = cp.pad(y_chunk, (pad_start, 0), mode="edge")
        if len(x_chunk) < chunk.shape[1]:
            pad_start = chunk.shape[1] - len(x_chunk)
            x_chunk = cp.pad(x_chunk, (pad_start, 0), mode="edge")
    # Broadcast the 1-D coordinate vectors into full 2-D grids.
    yy = cp.repeat(y_chunk.reshape(-1, 1), x_chunk.shape[0], axis=1)
    xx = cp.repeat(x_chunk.reshape(1, -1), y_chunk.shape[0], axis=0)
    return xx, yy
def _enrich_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) -> np.array:
    """Compute the 12 terrain feature layers for one overlap-padded DEM chunk on the GPU.

    Returns a float32 array of shape (12, H - 30, W - 30): the 15 px overlap halo
    (the large kernel's size_px) is trimmed from every spatial edge before returning,
    matching trim=False in the map_overlap call.
    """
    res = 32  # 32m resolution
    small_kernels = KernelFactory(res=res, size_px=3)  # ~3x3 kernels (96m)
    medium_kernels = KernelFactory(res=res, size_px=7)  # ~7x7 kernels (224m)
    large_kernels = KernelFactory(res=res, size_px=15)  # ~15x15 kernels (480m)
    # Check if there is data in the chunk
    if np.all(np.isnan(chunk)):
        # Return an array of NaNs with the expected shape
        return np.full(
            (12, chunk.shape[0] - 2 * large_kernels.size_px, chunk.shape[1] - 2 * large_kernels.size_px),
            np.nan,
            dtype=np.float32,
        )
    chunk = cp.asarray(chunk)
    # Interpolate missing values in chunk (for patches smaller than 7x7 pixels)
    # Erode/dilate with 3 iterations so only small NaN patches stay in the fill mask.
    mask = cp.isnan(chunk)
    mask &= ~binary_dilation(binary_erosion(mask, iterations=3, brute_force=True), iterations=3, brute_force=True)
    if cp.any(mask):
        # Find indices of valid values (nearest-neighbor fill via EDT indices).
        indices = distance_transform_edt(mask, return_distances=False, return_indices=True)
        chunk = chunk[tuple(indices)]
    # TPI
    tpi_small = tpi_cupy(chunk, small_kernels)
    tpi_medium = tpi_cupy(chunk, medium_kernels)
    tpi_large = tpi_cupy(chunk, large_kernels)
    # Slope
    slope = slope_cupy(chunk, res, res)
    # Aspect
    aspect = aspect_cupy(chunk)
    # NOTE(review): this rotates aspect by the local grid-north offset derived from
    # the cell coordinates — presumably correcting the polar-stereographic grid's
    # orientation to true north; confirm against the projection definition.
    xx, yy = _get_xy_chunk(chunk, x, y, block_info)
    aspect_correction = cp.arctan2(yy, xx) * (180 / cp.pi) + 90
    aspect = (aspect + aspect_correction) % 360
    # Curvature
    curvature = curvature_cupy(chunk, res)
    # TRI
    tri_small = tri_cupy(chunk, small_kernels)
    tri_medium = tri_cupy(chunk, medium_kernels)
    tri_large = tri_cupy(chunk, large_kernels)
    # Ruggedness
    vrm_small = ruggedness_cupy(chunk, slope, aspect, small_kernels)
    vrm_medium = ruggedness_cupy(chunk, slope, aspect, medium_kernels)
    vrm_large = ruggedness_cupy(chunk, slope, aspect, large_kernels)
    # Stack on GPU, then move to CPU once
    res_gpu = cp.stack(
        [
            tpi_small,
            tpi_medium,
            tpi_large,
            slope,
            aspect,
            curvature,
            tri_small,
            tri_medium,
            tri_large,
            vrm_small,
            vrm_medium,
            vrm_large,
        ],
        axis=0,
    )
    # Trim edges on GPU
    e = large_kernels.size_px
    res_gpu = res_gpu[:, e:-e, e:-e]
    # Single CPU transfer at the end
    res = cp.asnumpy(res_gpu)
    return res
def _enrich(terrain: xr.DataArray):
    """Lazily derive 12 terrain feature layers from a DEM DataArray.

    Wraps _enrich_chunk with dask.array.map_overlap (15 px halo, NaN boundary,
    trim handled inside the chunk function) and returns the stacked result as a
    (feature, y, x) DataArray together with the ordered feature-name list.
    """
    # Order must match the cp.stack order inside _enrich_chunk.
    features = [
        "tpi_small",
        "tpi_medium",
        "tpi_large",
        "slope",
        "aspect",
        "curvature",
        "tri_small",
        "tri_medium",
        "tri_large",
        "vrm_small",
        "vrm_medium",
        "vrm_large",
    ]
    # One output layer per feature, same spatial chunking as the input DEM.
    output_chunks = (len(features), *tuple(terrain.data.chunks))
    enriched_da = dask.array.map_overlap(
        _enrich_chunk,
        terrain.data,
        # dask.array.from_array(terrain.y.data.reshape(-1, 1), chunks=(terrain.data.chunks[0][0], 3600)).repeat(
        #     terrain.x.size, axis=1
        # ),
        # dask.array.from_array(terrain.x.data.reshape(1, -1), chunks=(3600, terrain.data.chunks[1][0])).repeat(
        #     terrain.y.size, axis=0
        # ),
        x=terrain.x.to_numpy(),
        y=terrain.y.to_numpy(),
        depth=15,  # large_kernels.size_px
        chunks=output_chunks,
        new_axis=0,
        dtype=np.float32,
        meta=np.array((), dtype=np.float32),
        trim=False,  # _enrich_chunk trims the halo itself
        boundary=np.nan,
        token="enrich_arcticdem",
    )
    enriched_da = xr.DataArray(
        enriched_da,
        dims=("feature", "y", "x"),
        coords={
            "feature": features,
            "y": terrain.y,
            "x": terrain.x,
        },
    )
    return enriched_da, features
@cli.command()
def enrich():
    """Compute terrain features for the whole ArcticDEM store and commit them.

    Spins up a local dask cluster, lazily derives the 12 terrain layers from the
    stored DEM, then writes them to the icechunk repository in a single commit.
    """
    with (
        dd.LocalCluster(n_workers=7, threads_per_worker=2, memory_limit="30GB") as cluster,
        dd.Client(cluster) as client,
    ):
        print(client)
        print(client.dashboard_link)
        accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
        # Garbage collect from previous runs
        accessor.repo.garbage_collect(datetime.datetime.now(datetime.UTC))
        adem = accessor.open_xarray()
        # session = adem.repo.readonly_session("main")
        # adem = xr.open_zarr(session.store, mask_and_scale=False, consolidated=False).set_coords("spatial_ref")
        # Drop _FillValue from the coordinate attrs so the write does not re-encode them.
        del adem.y.attrs["_FillValue"]
        del adem.x.attrs["_FillValue"]
        enriched, new_features = _enrich(adem.dem)
        for feature in new_features:
            adem[feature] = enriched.sel(feature=feature)
        print(adem[new_features])
        # subset = adem[new_features].isel(x=slice(190800, 220000), y=slice(61200, 100000))
        encodings = {feature: {"compressors": [BloscCodec(clevel=5)]} for feature in new_features}
        # subset.to_zarr("test2.zarr", mode="w", encoding=encodings)
        session = accessor.repo.writable_session("main")
        icechunk.xarray.to_icechunk(
            adem[new_features],
            session,
            mode="a",
            encoding=encodings,
        )
        session.commit("Add terrain features: TPI, slope, aspect, curvature, TRI, ruggedness")
        print("Enrichment complete.")
if __name__ == "__main__":
    # Allow running the module directly, equivalent to the `arcticdem` script entry point.
    cli()

View file

@ -71,6 +71,10 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
darts_counts = grid_gdf[darts_counts_columns] darts_counts = grid_gdf[darts_counts_columns]
grid_gdf["darts_rts_count"] = darts_counts.dropna(axis=0, how="all").sum(axis=1) grid_gdf["darts_rts_count"] = darts_counts.dropna(axis=0, how="all").sum(axis=1)
darts_density_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_density")]
darts_density = grid_gdf[darts_density_columns]
grid_gdf["darts_rts_density"] = darts_density.dropna(axis=0, how="all").max(axis=1)
output_path = get_darts_rts_file(grid, level) output_path = get_darts_rts_file(grid, level)
grid_gdf.to_parquet(output_path) grid_gdf.to_parquet(output_path)
print(f"Saved RTS labels to {output_path}") print(f"Saved RTS labels to {output_path}")

View file

@ -4,6 +4,7 @@ Author: Tobias Hölzer
Date: 09. June 2025 Date: 09. June 2025
""" """
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Literal from typing import Literal
import cartopy.crs as ccrs import cartopy.crs as ccrs
@ -16,8 +17,9 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import xarray as xr import xarray as xr
import xdggs import xdggs
import xvec # noqa: F401 import xvec
from rich import pretty, print, traceback from rich import pretty, print, traceback
from rich.progress import track
from shapely.geometry import Polygon from shapely.geometry import Polygon
from shapely.ops import transform from shapely.ops import transform
from stopuhr import stopwatch from stopuhr import stopwatch
@ -64,6 +66,24 @@ def get_cell_ids(grid: Literal["hex", "healpix"], level: int):
return cell_ids return cell_ids
def _get_cell_polygon(hex0_cell, resolution: int) -> tuple[list[Polygon], list[str], list[float]]:
hex_batch = []
hex_id_batch = []
hex_area_batch = []
hex_cells = h3.cell_to_children(hex0_cell, resolution)
for hex_id in hex_cells:
boundary_coords = h3.cell_to_boundary(hex_id)
hex_polygon = Polygon(boundary_coords)
hex_polygon = transform(lambda x, y: (y, x), hex_polygon) # Convert from (lat, lon) to (lon, lat)
hex_area = h3.cell_area(hex_id, unit="km^2")
hex_batch.append(hex_polygon)
hex_id_batch.append(hex_id)
hex_area_batch.append(hex_area)
return hex_batch, hex_id_batch, hex_area_batch
@stopwatch("Create a global hex grid") @stopwatch("Create a global hex grid")
def create_global_hex_grid(resolution): def create_global_hex_grid(resolution):
"""Create a global hexagonal grid using H3. """Create a global hexagonal grid using H3.
@ -76,23 +96,37 @@ def create_global_hex_grid(resolution):
""" """
# Generate hexagons # Generate hexagons
# For resolutions >=5, use multiprocessing to speed up the process
hex0_cells = h3.get_res0_cells() hex0_cells = h3.get_res0_cells()
if resolution > 0:
hex_cells = []
for hex0_cell in hex0_cells:
hex_cells.extend(h3.cell_to_children(hex0_cell, resolution))
else:
hex_cells = hex0_cells
# Initialize lists to store hex information # Initialize lists to store hex information
hex_list = [] hex_list = []
hex_id_list = [] hex_id_list = []
hex_area_list = [] hex_area_list = []
if resolution >= 5:
with ProcessPoolExecutor(max_workers=20) as executor:
future_to_hex = {
executor.submit(_get_cell_polygon, hex0_cell, resolution): hex0_cell for hex0_cell in hex0_cells
}
for future in track(
as_completed(future_to_hex), description="Creating hex polygons...", total=len(hex0_cells)
):
hex_batch, hex_id_batch, hex_area_batch = future.result()
hex_list.extend(hex_batch)
hex_id_list.extend(hex_id_batch)
hex_area_list.extend(hex_area_batch)
else:
if resolution > 0:
hex_cells = []
for hex0_cell in track(hex0_cells, description="Generating hex ids...", total=len(hex0_cells)):
hex_cells.extend(h3.cell_to_children(hex0_cell, resolution))
else:
hex_cells = hex0_cells
# Convert each hex ID to a polygon # Convert each hex ID to a polygon
for hex_id in hex_cells: for hex_id in track(hex_cells, description="Creating hex polygons...", total=len(hex_cells)):
boundary_coords = h3.cell_to_boundary(hex_id) boundary_coords = h3.cell_to_boundary(hex_id)
hex_polygon = Polygon(boundary_coords) hex_polygon = Polygon(boundary_coords)
hex_polygon = transform(lambda x, y: (y, x), hex_polygon) # Convert from (lat, lon) to (lon, lat) hex_polygon = transform(lambda x, y: (y, x), hex_polygon) # Convert from (lat, lon) to (lon, lat)

View file

@ -36,7 +36,7 @@ def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifi
print(f"Predicting probabilities for {len(data)} cells...") print(f"Predicting probabilities for {len(data)} cells...")
# Predict in batches to avoid memory issues # Predict in batches to avoid memory issues
batch_size = 100_000 batch_size = 10_000
preds = [] preds = []
for i in range(0, len(data), batch_size): for i in range(0, len(data), batch_size):
batch = data.iloc[i : i + batch_size] batch = data.iloc[i : i + batch_size]
@ -45,7 +45,7 @@ def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifi
X_batch = batch.drop(columns=cols_to_drop).dropna() X_batch = batch.drop(columns=cols_to_drop).dropna()
cell_ids = batch.loc[X_batch.index, "cell_id"].to_numpy() cell_ids = batch.loc[X_batch.index, "cell_id"].to_numpy()
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy() cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
X_batch = X_batch.to_numpy(dtype="float32") X_batch = X_batch.to_numpy(dtype="float64")
X_batch = torch.asarray(X_batch, device=0) X_batch = torch.asarray(X_batch, device=0)
batch_preds = clf.predict(X_batch).cpu().numpy() batch_preds = clf.predict(X_batch).cpu().numpy()
batch_preds = gpd.GeoDataFrame( batch_preds = gpd.GeoDataFrame(

View file

@ -27,6 +27,8 @@ WATERMASK_DIR.mkdir(parents=True, exist_ok=True)
TRAINING_DIR.mkdir(parents=True, exist_ok=True) TRAINING_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True) RESULTS_DIR.mkdir(parents=True, exist_ok=True)
arcticdem_store = DATA_DIR / "arcticdem32m.icechunk.zarr"
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp" watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
dartsl2_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet" dartsl2_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"

View file

@ -111,7 +111,10 @@ def get_available_result_files() -> list[Path]:
if result_file.exists() and state_file.exists() and preds_file.exists() and settings_file.exists(): if result_file.exists() and state_file.exists() and preds_file.exists() and settings_file.exists():
result_files.append(search_dir) result_files.append(search_dir)
return sorted(result_files, reverse=True) # Most recent first def _key_func(path: Path):
return path.stat().st_mtime
return sorted(result_files, key=_key_func, reverse=True) # Most recent first
def load_and_prepare_results(file_path: Path, settings: dict, k_bin_width: int = 40) -> pd.DataFrame: def load_and_prepare_results(file_path: Path, settings: dict, k_bin_width: int = 40) -> pd.DataFrame: