Add arcticdem
This commit is contained in: parent d5b35d6da4, commit 1a71883999
18 changed files with 6981 additions and 2068 deletions

.python-version
@@ -1 +1 @@
-3.12
+3.13

pyproject.toml
@@ -4,7 +4,7 @@ version = "0.1.0"
 description = "Add your description here"
 readme = "README.md"
 authors = [{ name = "Tobias Hölzer", email = "tobiashoelzer@hotmail.com" }]
-# requires-python = ">=3.10,<3.13"
+requires-python = ">=3.13,<3.14"
 dependencies = [
     "aiohttp>=3.12.11",
     "bokeh>=3.7.3",
@@ -28,19 +28,19 @@ dependencies = [
     "mapclassify>=2.9.0",
     "matplotlib>=3.10.3",
     "netcdf4>=1.7.2",
-    "numba>=0.62.1",
+    "numba>=0.61.2",
     "numbagg>=0.9.3",
-    "numpy>=2.3.0",
+    "numpy>=2.2.6",
     "odc-geo[all]>=0.4.10",
     "opt-einsum>=3.4.0",
-    "pyarrow>=20.0.0",
+    "pyarrow>=18.1.0",
     "rechunker>=0.5.2",
     "requests>=2.32.3",
     "rich>=14.0.0",
     "rioxarray>=0.19.0",
     "scipy>=1.15.3",
     "seaborn>=0.13.2",
-    "smart-geocubes[gee,dask,stac,viz]>=0.0.9",
+    "smart-geocubes[stac]>=0.1.0",
     "stopuhr>=0.0.10",
     "ultraplot>=1.63.0",
     "xanimate",
@@ -51,7 +51,13 @@ dependencies = [
     "geocube>=0.7.1,<0.8",
     "streamlit>=1.50.0,<2",
     "altair[all]>=5.5.0,<6",
-    "h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26", "st-theme>=1.2.3,<2",
+    "h5netcdf>=1.7.3,<2",
+    "streamlit-folium>=0.25.3,<0.26",
+    "st-theme>=1.2.3,<2",
+    "xgboost>=3.1.1,<4",
+    "s3fs>=2025.10.0,<2026",
+    "xarray-spatial",
+    "cupy-xarray>=0.1.4,<0.2",
 ]
 
 [project.scripts]
@@ -59,6 +65,7 @@ create-grid = "entropice.grids:main"
 darts = "entropice.darts:main"
 alpha-earth = "entropice.alphaearth:main"
 era5 = "entropice.era5:cli"
+arcticdem = "entropice.arcticdem:cli"
 train = "entropice.training:main"
 dataset = "entropice.dataset:main"
 
@@ -69,10 +76,20 @@ build-backend = "hatchling.build"
 [tool.uv]
 package = true
 
+[[tool.uv.index]]
+name = "nvidia"
+url = "https://pypi.nvidia.com"
+explicit = true
+
 [tool.uv.sources]
 entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
 entropy = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" }
 xanimate = { git = "https://github.com/davbyr/xAnimate" }
+xdem = { git = "https://github.com/GlacioHack/xdem" }
+xarray-spatial = { git = "https://github.com/relativityhd/xarray-spatial" }
+cudf-cu12 = { index = "nvidia" }
+cuml-cu12 = { index = "nvidia" }
+cuspatial-cu12 = { index = "nvidia" }
 
 [tool.ruff.lint.pyflakes]
 # Ignore libraries when checking for unused imports
@@ -87,7 +104,8 @@ allowed-unused-imports = [
 ]
 
 [tool.pixi.workspace]
-channels = ["conda-forge"]
+channels = ["nvidia", "conda-forge", "rapidsai"]
+channel-priority = "disabled"
 platforms = ["linux-64"]
 
 [tool.pixi.activation.env]
@@ -116,4 +134,5 @@ cupy = ">=13.6.0,<14"
 nccl = ">=2.27.7.1,<3"
 cudnn = ">=9.13.1.26,<10"
 cusparselt = ">=0.8.1.1,<0.9"
-cuda-version = "12.1.*"
+cuda-version = "12.9.*"
+rapids = ">=25.10.0,<26"

scripts/00grids.sh (new file, 11 lines)
@@ -0,0 +1,11 @@
+#! /bin/bash
+
+pixi run create-grid --grid hex --level 3
+pixi run create-grid --grid hex --level 4
+pixi run create-grid --grid hex --level 5
+pixi run create-grid --grid hex --level 6
+pixi run create-grid --grid healpix --level 6
+pixi run create-grid --grid healpix --level 7
+pixi run create-grid --grid healpix --level 8
+pixi run create-grid --grid healpix --level 9
+pixi run create-grid --grid healpix --level 10

scripts/01darts.sh (new file, 11 lines)
@@ -0,0 +1,11 @@
+#! /bin/bash
+
+pixi run darts --grid hex --level 3
+pixi run darts --grid hex --level 4
+pixi run darts --grid hex --level 5
+pixi run darts --grid hex --level 6
+pixi run darts --grid healpix --level 6
+pixi run darts --grid healpix --level 7
+pixi run darts --grid healpix --level 8
+pixi run darts --grid healpix --level 9
+pixi run darts --grid healpix --level 10

scripts/02alphaearth.sh (new file, 23 lines)
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+pixi run alpha-earth download --grid hex --level 3
+pixi run alpha-earth download --grid hex --level 4
+pixi run alpha-earth download --grid hex --level 5
+pixi run alpha-earth download --grid healpix --level 6
+pixi run alpha-earth download --grid healpix --level 7
+pixi run alpha-earth download --grid healpix --level 8
+pixi run alpha-earth download --grid healpix --level 9
+
+pixi run alpha-earth combine-to-zarr --grid hex --level 3
+pixi run alpha-earth combine-to-zarr --grid hex --level 4
+pixi run alpha-earth combine-to-zarr --grid hex --level 5
+pixi run alpha-earth combine-to-zarr --grid healpix --level 6
+pixi run alpha-earth combine-to-zarr --grid healpix --level 7
+pixi run alpha-earth combine-to-zarr --grid healpix --level 8
+pixi run alpha-earth combine-to-zarr --grid healpix --level 9
+
+pixi run alpha-earth download --grid hex --level 6
+pixi run alpha-earth download --grid healpix --level 10
+
+pixi run alpha-earth combine-to-zarr --grid hex --level 6
+pixi run alpha-earth combine-to-zarr --grid healpix --level 10

scripts/03era5.sh (new file, 15 lines)
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+# pixi run era5 download
+# pixi run era5 enrich
+
+pixi run era5 spatial-agg --grid hex --level 3
+pixi run era5 spatial-agg --grid hex --level 4
+pixi run era5 spatial-agg --grid hex --level 5
+pixi run era5 spatial-agg --grid hex --level 6
+
+pixi run era5 spatial-agg --grid healpix --level 6
+pixi run era5 spatial-agg --grid healpix --level 7
+pixi run era5 spatial-agg --grid healpix --level 8
+pixi run era5 spatial-agg --grid healpix --level 9
+pixi run era5 spatial-agg --grid healpix --level 10

(deleted file)
@@ -1,17 +0,0 @@
-#!/bin/bash
-
-# uv run alpha-earth download --grid hex --level 3
-# uv run alpha-earth download --grid hex --level 4
-# uv run alpha-earth download --grid hex --level 5
-# uv run alpha-earth download --grid healpix --level 6
-# uv run alpha-earth download --grid healpix --level 7
-# uv run alpha-earth download --grid healpix --level 8
-# uv run alpha-earth download --grid healpix --level 9
-
-uv run alpha-earth combine-to-zarr --grid hex --level 3
-uv run alpha-earth combine-to-zarr --grid hex --level 4
-uv run alpha-earth combine-to-zarr --grid hex --level 5
-uv run alpha-earth combine-to-zarr --grid healpix --level 6
-uv run alpha-earth combine-to-zarr --grid healpix --level 7
-uv run alpha-earth combine-to-zarr --grid healpix --level 8
-uv run alpha-earth combine-to-zarr --grid healpix --level 9

(deleted file)
@@ -1,66 +0,0 @@
-"""Download ERA5 data from the Copernicus Data Store.
-
-Web platform: https://cds.climate.copernicus.eu
-"""
-
-import os
-import re
-from pathlib import Path
-
-import cdsapi
-import cyclopts
-from rich import pretty, print, traceback
-
-traceback.install()
-pretty.install()
-
-DATA_DIR = Path(os.environ.get("DATA_DIR", "../../data")) / "entropyc-rts"
-
-
-def hourly(years: str):
-    """Download ERA5 data from the Copernicus Data Store.
-
-    Args:
-        years (str): Years to download, separated by a '-'.
-
-    """
-    assert re.compile(r"^\d{4}-\d{4}$").match(years), "Years must be in the format 'YYYY-YYYY'"
-    start_year, end_year = map(int, years.split("-"))
-    assert 1950 <= start_year <= end_year <= 2024, "Years must be between 1950 and 2024"
-
-    dataset = "reanalysis-era5-single-levels"
-    client = cdsapi.Client(wait_until_complete=False)
-
-    outdir = (DATA_DIR / "era5/cds").resolve()
-    outdir.mkdir(parents=True, exist_ok=True)
-
-    print(f"Downloading ERA5 data from {start_year} to {end_year}...")
-    for y in range(start_year, end_year + 1):
-        for month in [f"{i:02d}" for i in range(1, 13)]:
-            request = {
-                "product_type": ["reanalysis"],
-                "variable": [
-                    "2m_temperature",
-                    "total_precipitation",
-                    "snow_depth",
-                    "snow_density",
-                    "snowfall",
-                    "lake_ice_temperature",
-                    "surface_sensible_heat_flux",
-                ],
-                "year": [str(y)],
-                "month": [month],
-                "day": [f"{i:02d}" for i in range(1, 32)],
-                "time": [f"{i:02d}:00" for i in range(0, 24)],
-                "data_format": "netcdf",
-                "download_format": "unarchived",
-                "area": [85, -180, 50, 180],
-            }
-
-            outpath = outdir / f"era5_{y}_{month}.zip"
-            client.retrieve(dataset, request).download(str(outpath))
-            print(f"Downloaded {dataset} for {y}-{month}")
-
-
-if __name__ == "__main__":
-    cyclopts.run(hourly)

(deleted file)
@@ -1,18 +0,0 @@
-#!/bin/bash
-
-# uv run era5 download
-# uv run era5 enrich
-
-# Can be summer, winter or yearly
-agg=$1
-
-echo "Running ERA5 spatial aggregation for aggregation type: $agg"
-
-uv run era5 spatial-agg --grid hex --level 3 --agg $agg
-uv run era5 spatial-agg --grid hex --level 4 --agg $agg
-uv run era5 spatial-agg --grid hex --level 5 --agg $agg
-
-uv run era5 spatial-agg --grid healpix --level 6 --agg $agg
-uv run era5 spatial-agg --grid healpix --level 7 --agg $agg
-uv run era5 spatial-agg --grid healpix --level 8 --agg $agg
-uv run era5 spatial-agg --grid healpix --level 9 --agg $agg

(deleted file)
@@ -1,9 +0,0 @@
-#!/bin/bash
-
-uv run rts --grid hex --level 3
-uv run rts --grid hex --level 4
-uv run rts --grid hex --level 5
-uv run rts --grid healpix --level 6
-uv run rts --grid healpix --level 7
-uv run rts --grid healpix --level 8
-uv run rts --grid healpix --level 9

src/entropice/alphaearth.py
@@ -30,6 +30,8 @@ ee.Initialize(project="ee-tobias-hoelzer")
 
 cli = cyclopts.App(name="alpha-earth")
 
+# 7454521782,230147807,10000000.
+
 
 @cli.command()
 def download(grid: Literal["hex", "healpix"], level: int):
@@ -60,6 +62,8 @@ def download(grid: Literal["hex", "healpix"], level: int):
         .combine(ee.Reducer.mean(), sharedInputs=True)
         .combine(ee.Reducer.percentile([1, 5, 25, 75, 95, 99]), sharedInputs=True),
+        geometry=geom,
+        scale=10,
         bestEffort=True,
     )
     # Add mean embedding values as properties to the feature
     return feature.set(mean_dict)
@@ -85,7 +89,7 @@ def download(grid: Literal["hex", "healpix"], level: int):
 
     # Combine all batch results
     df = pd.concat(all_results, ignore_index=True)
-    embeddings_on_grid = grid.merge(df[[*bands, "cell_id"]], on="cell_id", how="left")
+    embeddings_on_grid = grid_gdf.merge(df[[*bands, "cell_id"]], on="cell_id", how="left")
     embeddings_file = get_annual_embeddings_file(grid, level, year)
     embeddings_on_grid.to_parquet(embeddings_file)
     print(f"Saved embeddings for year {year} to {embeddings_file}.")
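The one-line fix above matters because the grid parameter here is the "hex"/"healpix" string, not the grid table itself; the merge has to run against grid_gdf so each cell keeps its geometry and picks up the downloaded band statistics. A minimal sketch of that left join (hypothetical toy data):

    import geopandas as gpd
    import pandas as pd
    from shapely.geometry import Point

    # One polygon/point per cell on the grid side ...
    grid_gdf = gpd.GeoDataFrame(
        {"cell_id": ["a", "b"], "geometry": [Point(0, 0), Point(1, 1)]}, crs="EPSG:4326"
    )
    # ... and per-cell statistics on the result side.
    df = pd.DataFrame({"cell_id": ["a"], "A00_mean": [0.12]})
    # how="left" keeps every grid cell; cells without results get NaN columns.
    out = grid_gdf.merge(df, on="cell_id", how="left")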

src/entropice/arcticdem.py (new file, 308 lines)
@@ -0,0 +1,308 @@
+import datetime
+from dataclasses import dataclass
+from math import ceil
+
+import cupy as cp
+import cupy_xarray
+import cupyx.scipy.signal
+import cyclopts
+import dask.array
+import dask.distributed as dd
+import icechunk.xarray
+import numpy as np
+import smart_geocubes
+import xarray as xr
+import xrspatial
+import zarr
+from cupyx.scipy.ndimage import binary_dilation, binary_erosion, distance_transform_edt
+from rich import pretty, print, traceback
+from stopuhr import stopwatch
+from xrspatial.aspect import _run_cupy as aspect_cupy
+from xrspatial.curvature import _run_cupy as curvature_cupy
+from xrspatial.slope import _run_cupy as slope_cupy
+from zarr.codecs import BloscCodec
+
+from entropice.paths import arcticdem_store
+
+traceback.install(show_locals=True, suppress=[cyclopts])
+pretty.install()
+
+zarr.config.set({"async.concurrency": 128})
+
+cli = cyclopts.App(name="arcticdem")
+
+
+@cli.command()
+def download():
+    adem = smart_geocubes.ArcticDEM32m(arcticdem_store)
+    adem.create(exists_ok=True)
+    with stopwatch("Download ArcticDEM data"):
+        adem.procedural_download(adem.extent.extent, None)
+
+
+@dataclass
+class KernelFactory:
+    res: int
+    size_px: int
+
+    @property
+    def size(self):
+        return self.size_px * self.res
+
+    @property
+    def inner(self):
+        return ceil(self.size / 3)
+
+    @staticmethod
+    def _to_cupy_f32(kernel):
+        return cp.asarray(kernel.astype("float32"))
+
+    def ring(self):
+        return self._to_cupy_f32(xrspatial.convolution.annulus_kernel(self.res, self.res, self.size, self.inner))
+
+    def tri(self):
+        kernel = np.ones((self.size_px, self.size_px), dtype=float)
+        kernel[self.size_px // 2, self.size_px // 2] = 0  # Set the center cell to 0
+        return self._to_cupy_f32(xrspatial.convolution.custom_kernel(kernel))
+
+    def vrm(self):
+        kernel = np.ones((self.size_px, self.size_px), dtype=float) / self.size_px**2
+        return self._to_cupy_f32(xrspatial.convolution.custom_kernel(kernel))
+
+
+def tpi_cupy(chunk, kernels: KernelFactory):
+    kernel = kernels.ring()
+    kernel = kernel / cp.nansum(kernel)
+    tpi = chunk - cupyx.scipy.signal.convolve2d(chunk, kernel, mode="same")
+    return tpi
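For orientation: tpi_cupy implements the topographic position index as a cell's elevation minus the mean elevation of an annulus-shaped neighborhood, so ridges come out positive and depressions negative. A minimal CPU-side sketch of the same idea (hypothetical 5x5 input, square 3x3 neighborhood instead of the annulus):

    import numpy as np
    from scipy.signal import convolve2d

    dem = np.arange(25, dtype=float).reshape(5, 5)
    kernel = np.ones((3, 3))
    kernel[1, 1] = 0          # neighbors only, exclude the center
    kernel /= kernel.sum()    # convolution now yields the neighborhood mean
    tpi = dem - convolve2d(dem, kernel, mode="same")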
+
+
+def tri_cupy(chunk, kernels: KernelFactory):
+    kernel = kernels.tri()
+    c2 = chunk**2
+    focal_sum = cupyx.scipy.signal.convolve2d(chunk, kernel, mode="same")
+    focal_sum_squared = cupyx.scipy.signal.convolve2d(c2, kernel, mode="same")
+    tri = np.sqrt((kernel.size - 1) * c2 - 2 * chunk * focal_sum + focal_sum_squared)
+    return tri
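tri_cupy avoids a per-cell neighbor loop by expanding the sum of squared differences: for a center z_c with n neighbors, sum_i (z_i - z_c)^2 = n*z_c^2 - 2*z_c*sum_i(z_i) + sum_i(z_i^2), which reduces the ruggedness index to two convolutions (one over z, one over z^2) since the kernel zeroes the center cell. A quick NumPy check of the identity (hypothetical values):

    import numpy as np

    zc = 3.0                              # center cell
    zi = np.array([1.0, 2.0, 4.0, 7.0])   # its neighbors
    direct = np.sum((zi - zc) ** 2)
    expanded = zi.size * zc**2 - 2 * zc * zi.sum() + np.sum(zi**2)
    assert np.isclose(direct, expanded)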
+
+
+def ruggedness_cupy(chunk, slope, aspect, kernels: KernelFactory):
+    slope_rad = slope * (cp.pi / 180)
+    aspect_rad = aspect * (cp.pi / 180)
+    aspect_rad = cp.where(aspect_rad == -1, 0, aspect_rad)
+
+    xy = cp.sin(slope_rad)
+    z = cp.cos(slope_rad)
+    x = xy * cp.sin(aspect_rad)
+    y = xy * cp.cos(aspect_rad)
+
+    kernel = kernels.vrm()
+
+    # Calculate sums of x, y, and z components in the neighborhood
+    x_sum = cupyx.scipy.signal.convolve2d(x, kernel, mode="same")
+    y_sum = cupyx.scipy.signal.convolve2d(y, kernel, mode="same")
+    z_sum = cupyx.scipy.signal.convolve2d(z, kernel, mode="same")
+
+    # Calculate the resultant vector magnitude
+    vrm = 1 - np.sqrt(x_sum**2 + y_sum**2 + z_sum**2)
+    return vrm
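ruggedness_cupy computes the vector ruggedness measure: each cell's slope/aspect pair becomes a unit surface-normal vector, the vectors are averaged over the window, and VRM is 1 minus the length of that mean vector, so uniform terrain scores 0 and chaotic orientations approach 1. A tiny sanity check of the decomposition (hypothetical two-cell window with identical orientation):

    import numpy as np

    # Two cells with the same orientation -> resultant length 1 -> VRM 0.
    slope = np.deg2rad(np.array([10.0, 10.0]))
    aspect = np.deg2rad(np.array([90.0, 90.0]))
    x = np.sin(slope) * np.sin(aspect)
    y = np.sin(slope) * np.cos(aspect)
    z = np.cos(slope)
    vrm = 1 - np.sqrt(x.mean() ** 2 + y.mean() ** 2 + z.mean() ** 2)
    assert np.isclose(vrm, 0.0)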
+
+
+def _get_xy_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) -> tuple[cp.array, cp.array]:
+    chunk_loc = block_info[0]["chunk-location"]
+    d = 15
+    cs = 3600
+
+    # Calculate safe slice bounds for edge chunks
+    y_start, y_end = max(0, cs * chunk_loc[0] - d), min(len(y), cs * chunk_loc[0] + cs + d)
+    x_start, x_end = max(0, cs * chunk_loc[1] - d), min(len(x), cs * chunk_loc[1] + cs + d)
+
+    # Extract coordinate arrays with safe bounds
+    y_chunk = cp.asarray(y[y_start:y_end])
+    x_chunk = cp.asarray(x[x_start:x_end])
+
+    # Handle cases where the extracted chunk doesn't match expected dimensions
+    if len(y_chunk) != chunk.shape[0] or len(x_chunk) != chunk.shape[1]:
+        print(
+            f"Adjusting coordinate chunk sizes: y_chunk {len(y_chunk)} vs chunk {chunk.shape[0]}, "
+            f"x_chunk {len(x_chunk)} vs chunk {chunk.shape[1]}"
+        )
+        # Pad coordinates to match chunk dimensions if needed
+        if len(y_chunk) < chunk.shape[0]:
+            # Pad with the edge values
+            pad_start = chunk.shape[0] - len(y_chunk)
+            y_chunk = cp.pad(y_chunk, (pad_start, 0), mode="edge")
+        if len(x_chunk) < chunk.shape[1]:
+            pad_start = chunk.shape[1] - len(x_chunk)
+            x_chunk = cp.pad(x_chunk, (pad_start, 0), mode="edge")
+
+    yy = cp.repeat(y_chunk.reshape(-1, 1), x_chunk.shape[0], axis=1)
+    xx = cp.repeat(x_chunk.reshape(1, -1), y_chunk.shape[0], axis=0)
+    return xx, yy
+
+
+def _enrich_chunk(chunk: np.array, x: np.array, y: np.array, block_info=None) -> np.array:
+    res = 32  # 32m resolution
+    small_kernels = KernelFactory(res=res, size_px=3)  # ~3x3 kernels (96m)
+    medium_kernels = KernelFactory(res=res, size_px=7)  # ~7x7 kernels (224m)
+    large_kernels = KernelFactory(res=res, size_px=15)  # ~15x15 kernels (480m)
+
+    # Check if there is data in the chunk
+    if np.all(np.isnan(chunk)):
+        # Return an array of NaNs with the expected shape
+        return np.full(
+            (12, chunk.shape[0] - 2 * large_kernels.size_px, chunk.shape[1] - 2 * large_kernels.size_px),
+            np.nan,
+            dtype=np.float32,
+        )
+
+    chunk = cp.asarray(chunk)
+
+    # Interpolate missing values in chunk (for patches smaller than 7x7 pixels)
+    mask = cp.isnan(chunk)
+    mask &= ~binary_dilation(binary_erosion(mask, iterations=3, brute_force=True), iterations=3, brute_force=True)
+    if cp.any(mask):
+        # Find indices of valid values
+        indices = distance_transform_edt(mask, return_distances=False, return_indices=True)
+        chunk = chunk[tuple(indices)]
+
+    # TPI
+    tpi_small = tpi_cupy(chunk, small_kernels)
+    tpi_medium = tpi_cupy(chunk, medium_kernels)
+    tpi_large = tpi_cupy(chunk, large_kernels)
+    # Slope
+    slope = slope_cupy(chunk, res, res)
+    # Aspect
+    aspect = aspect_cupy(chunk)
+    xx, yy = _get_xy_chunk(chunk, x, y, block_info)
+    aspect_correction = cp.arctan2(yy, xx) * (180 / cp.pi) + 90
+    aspect = (aspect + aspect_correction) % 360
+    # Curvature
+    curvature = curvature_cupy(chunk, res)
+    # TRI
+    tri_small = tri_cupy(chunk, small_kernels)
+    tri_medium = tri_cupy(chunk, medium_kernels)
+    tri_large = tri_cupy(chunk, large_kernels)
+    # Ruggedness
+    vrm_small = ruggedness_cupy(chunk, slope, aspect, small_kernels)
+    vrm_medium = ruggedness_cupy(chunk, slope, aspect, medium_kernels)
+    vrm_large = ruggedness_cupy(chunk, slope, aspect, large_kernels)
+
+    # Stack on GPU, then move to CPU once
+    res_gpu = cp.stack(
+        [
+            tpi_small,
+            tpi_medium,
+            tpi_large,
+            slope,
+            aspect,
+            curvature,
+            tri_small,
+            tri_medium,
+            tri_large,
+            vrm_small,
+            vrm_medium,
+            vrm_large,
+        ],
+        axis=0,
+    )
+
+    # Trim edges on GPU
+    e = large_kernels.size_px
+    res_gpu = res_gpu[:, e:-e, e:-e]
+
+    # Single CPU transfer at the end
+    res = cp.asnumpy(res_gpu)
+    return res
+
+
+def _enrich(terrain: xr.DataArray):
+    features = [
+        "tpi_small",
+        "tpi_medium",
+        "tpi_large",
+        "slope",
+        "aspect",
+        "curvature",
+        "tri_small",
+        "tri_medium",
+        "tri_large",
+        "vrm_small",
+        "vrm_medium",
+        "vrm_large",
+    ]
+    output_chunks = (len(features), *tuple(terrain.data.chunks))
+    enriched_da = dask.array.map_overlap(
+        _enrich_chunk,
+        terrain.data,
+        # dask.array.from_array(terrain.y.data.reshape(-1, 1), chunks=(terrain.data.chunks[0][0], 3600)).repeat(
+        #     terrain.x.size, axis=1
+        # ),
+        # dask.array.from_array(terrain.x.data.reshape(1, -1), chunks=(3600, terrain.data.chunks[1][0])).repeat(
+        #     terrain.y.size, axis=0
+        # ),
+        x=terrain.x.to_numpy(),
+        y=terrain.y.to_numpy(),
+        depth=15,  # large_kernels.size_px
+        chunks=output_chunks,
+        new_axis=0,
+        dtype=np.float32,
+        meta=np.array((), dtype=np.float32),
+        trim=False,
+        boundary=np.nan,
+        token="enrich_arcticdem",
+    )
+    enriched_da = xr.DataArray(
+        enriched_da,
+        dims=("feature", "y", "x"),
+        coords={
+            "feature": features,
+            "y": terrain.y,
+            "x": terrain.x,
+        },
+    )
+    return enriched_da, features
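_enrich wires _enrich_chunk into dask.array.map_overlap with trim=False: every chunk arrives padded with depth=15 cells of neighboring data (NaN at the array edge via boundary=np.nan), and the kernel trims that halo itself before returning. A reduced sketch of those overlap mechanics (hypothetical smooth function, toy sizes):

    import dask.array as da
    import numpy as np

    def smooth(block):
        # Each block arrives with a 1-cell halo; trim it after processing.
        out = block * 1.0  # stand-in for the real per-chunk work
        return out[1:-1, 1:-1]

    arr = da.ones((8, 8), chunks=(4, 4))
    res = da.map_overlap(smooth, arr, depth=1, boundary=np.nan, trim=False, dtype=float)
    res.compute()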
+
+
+@cli.command()
+def enrich():
+    with (
+        dd.LocalCluster(n_workers=7, threads_per_worker=2, memory_limit="30GB") as cluster,
+        dd.Client(cluster) as client,
+    ):
+        print(client)
+        print(client.dashboard_link)
+
+        accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
+
+        # Garbage collect from previous runs
+        accessor.repo.garbage_collect(datetime.datetime.now(datetime.UTC))
+
+        adem = accessor.open_xarray()
+        # session = adem.repo.readonly_session("main")
+        # adem = xr.open_zarr(session.store, mask_and_scale=False, consolidated=False).set_coords("spatial_ref")
+        del adem.y.attrs["_FillValue"]
+        del adem.x.attrs["_FillValue"]
+        enriched, new_features = _enrich(adem.dem)
+
+        for feature in new_features:
+            adem[feature] = enriched.sel(feature=feature)
+        print(adem[new_features])
+        # subset = adem[new_features].isel(x=slice(190800, 220000), y=slice(61200, 100000))
+        encodings = {feature: {"compressors": [BloscCodec(clevel=5)]} for feature in new_features}
+        # subset.to_zarr("test2.zarr", mode="w", encoding=encodings)
+
+        session = accessor.repo.writable_session("main")
+        icechunk.xarray.to_icechunk(
+            adem[new_features],
+            session,
+            mode="a",
+            encoding=encodings,
+        )
+        session.commit("Add terrain features: TPI, slope, aspect, curvature, TRI, ruggedness")
+
+    print("Enrichment complete.")
+
+
+if __name__ == "__main__":
+    cli()
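The enrich command writes the twelve feature arrays back into the same icechunk repository and records them as a commit, so the derived layers stay versioned next to the raw DEM. A minimal sketch of that write path against a throwaway in-memory repository (icechunk API names as I understand them from recent releases; treat as an assumption):

    import icechunk
    import numpy as np
    import xarray as xr
    from icechunk.xarray import to_icechunk

    # Hypothetical in-memory repo standing in for the on-disk arcticdem store.
    repo = icechunk.Repository.create(icechunk.in_memory_storage())
    session = repo.writable_session("main")
    ds = xr.Dataset({"slope": (("y", "x"), np.zeros((4, 4), dtype="float32"))})
    to_icechunk(ds, session, mode="w")
    session.commit("Add slope")

With the [project.scripts] entry above, the two stages would run as `pixi run arcticdem download` followed by `pixi run arcticdem enrich`.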

src/entropice/darts.py
@@ -71,6 +71,10 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
     darts_counts = grid_gdf[darts_counts_columns]
     grid_gdf["darts_rts_count"] = darts_counts.dropna(axis=0, how="all").sum(axis=1)
 
+    darts_density_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_density")]
+    darts_density = grid_gdf[darts_density_columns]
+    grid_gdf["darts_rts_density"] = darts_density.dropna(axis=0, how="all").max(axis=1)
+
     output_path = get_darts_rts_file(grid, level)
     grid_gdf.to_parquet(output_path)
     print(f"Saved RTS labels to {output_path}")

src/entropice/grids.py
@@ -4,6 +4,7 @@ Author: Tobias Hölzer
 Date: 09. June 2025
 """
 
+from concurrent.futures import ProcessPoolExecutor, as_completed
 from typing import Literal
 
 import cartopy.crs as ccrs
@@ -16,8 +17,9 @@ import matplotlib.pyplot as plt
 import numpy as np
 import xarray as xr
 import xdggs
-import xvec  # noqa: F401
+import xvec
 from rich import pretty, print, traceback
+from rich.progress import track
 from shapely.geometry import Polygon
 from shapely.ops import transform
 from stopuhr import stopwatch
@@ -64,6 +66,24 @@ def get_cell_ids(grid: Literal["hex", "healpix"], level: int):
     return cell_ids
 
 
+def _get_cell_polygon(hex0_cell, resolution: int) -> tuple[list[Polygon], list[str], list[float]]:
+    hex_batch = []
+    hex_id_batch = []
+    hex_area_batch = []
+
+    hex_cells = h3.cell_to_children(hex0_cell, resolution)
+
+    for hex_id in hex_cells:
+        boundary_coords = h3.cell_to_boundary(hex_id)
+        hex_polygon = Polygon(boundary_coords)
+        hex_polygon = transform(lambda x, y: (y, x), hex_polygon)  # Convert from (lat, lon) to (lon, lat)
+        hex_area = h3.cell_area(hex_id, unit="km^2")
+        hex_batch.append(hex_polygon)
+        hex_id_batch.append(hex_id)
+        hex_area_batch.append(hex_area)
+    return hex_batch, hex_id_batch, hex_area_batch
+
+
 @stopwatch("Create a global hex grid")
 def create_global_hex_grid(resolution):
     """Create a global hexagonal grid using H3.
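One detail worth knowing in _get_cell_polygon: h3's cell_to_boundary returns vertices as (lat, lng) pairs, while shapely expects (x, y) = (lon, lat), hence the transform(lambda x, y: (y, x), ...) swap. The pattern in isolation (assuming the h3-py v4 API):

    import h3
    from shapely.geometry import Polygon
    from shapely.ops import transform

    cell = h3.latlng_to_cell(69.0, 20.0, 3)       # some Arctic cell
    poly = Polygon(h3.cell_to_boundary(cell))     # vertices arrive as (lat, lng)
    poly = transform(lambda x, y: (y, x), poly)   # reorder to (lon, lat)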
@@ -76,30 +96,44 @@
 
     """
-    # Generate hexagons
+    # For resolutions >=5, use multiprocessing to speed up the process
     hex0_cells = h3.get_res0_cells()
 
-    if resolution > 0:
-        hex_cells = []
-        for hex0_cell in hex0_cells:
-            hex_cells.extend(h3.cell_to_children(hex0_cell, resolution))
-
-    else:
-        hex_cells = hex0_cells
-
     # Initialize lists to store hex information
     hex_list = []
     hex_id_list = []
     hex_area_list = []
 
-    # Convert each hex ID to a polygon
-    for hex_id in hex_cells:
-        boundary_coords = h3.cell_to_boundary(hex_id)
-        hex_polygon = Polygon(boundary_coords)
-        hex_polygon = transform(lambda x, y: (y, x), hex_polygon)  # Convert from (lat, lon) to (lon, lat)
-        hex_area = h3.cell_area(hex_id, unit="km^2")
-        hex_list.append(hex_polygon)
-        hex_id_list.append(hex_id)
-        hex_area_list.append(hex_area)
+    if resolution >= 5:
+        with ProcessPoolExecutor(max_workers=20) as executor:
+            future_to_hex = {
+                executor.submit(_get_cell_polygon, hex0_cell, resolution): hex0_cell for hex0_cell in hex0_cells
+            }
+            for future in track(
+                as_completed(future_to_hex), description="Creating hex polygons...", total=len(hex0_cells)
+            ):
+                hex_batch, hex_id_batch, hex_area_batch = future.result()
+                hex_list.extend(hex_batch)
+                hex_id_list.extend(hex_id_batch)
+                hex_area_list.extend(hex_area_batch)
+    else:
+        if resolution > 0:
+            hex_cells = []
+            for hex0_cell in track(hex0_cells, description="Generating hex ids...", total=len(hex0_cells)):
+                hex_cells.extend(h3.cell_to_children(hex0_cell, resolution))
+
+        else:
+            hex_cells = hex0_cells
+
+        # Convert each hex ID to a polygon
+        for hex_id in track(hex_cells, description="Creating hex polygons...", total=len(hex_cells)):
+            boundary_coords = h3.cell_to_boundary(hex_id)
+            hex_polygon = Polygon(boundary_coords)
+            hex_polygon = transform(lambda x, y: (y, x), hex_polygon)  # Convert from (lat, lon) to (lon, lat)
+            hex_area = h3.cell_area(hex_id, unit="km^2")
+            hex_list.append(hex_polygon)
+            hex_id_list.append(hex_id)
+            hex_area_list.append(hex_area)
 
     # Create GeoDataFrame
     grid = gpd.GeoDataFrame({"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list}, crs="EPSG:4326")
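The parallel branch fans out over the res-0 base cells (H3 defines 122 of them), so each worker expands one base cell to the target resolution and the parent process only concatenates the returned batches. The same submit/as_completed shape reduced to a toy task (hypothetical expand worker):

    from concurrent.futures import ProcessPoolExecutor, as_completed

    def expand(base: int) -> list[int]:
        # Stand-in for h3.cell_to_children on one res-0 cell
        return [base * 10 + i for i in range(7)]

    if __name__ == "__main__":
        results: list[int] = []
        with ProcessPoolExecutor(max_workers=4) as executor:
            futures = {executor.submit(expand, b): b for b in range(122)}
            for future in as_completed(futures):
                results.extend(future.result())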
@@ -36,7 +36,7 @@ def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifi
     print(f"Predicting probabilities for {len(data)} cells...")
 
     # Predict in batches to avoid memory issues
-    batch_size = 100_000
+    batch_size = 10_000
     preds = []
     for i in range(0, len(data), batch_size):
         batch = data.iloc[i : i + batch_size]
@@ -45,7 +45,7 @@ def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifi
         X_batch = batch.drop(columns=cols_to_drop).dropna()
         cell_ids = batch.loc[X_batch.index, "cell_id"].to_numpy()
         cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
-        X_batch = X_batch.to_numpy(dtype="float32")
+        X_batch = X_batch.to_numpy(dtype="float64")
         X_batch = torch.asarray(X_batch, device=0)
         batch_preds = clf.predict(X_batch).cpu().numpy()
         batch_preds = gpd.GeoDataFrame(

src/entropice/paths.py
@@ -27,6 +27,8 @@ WATERMASK_DIR.mkdir(parents=True, exist_ok=True)
 TRAINING_DIR.mkdir(parents=True, exist_ok=True)
 RESULTS_DIR.mkdir(parents=True, exist_ok=True)
 
+arcticdem_store = DATA_DIR / "arcticdem32m.icechunk.zarr"
+
 watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
 
 dartsl2_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
@@ -111,7 +111,10 @@ def get_available_result_files() -> list[Path]:
         if result_file.exists() and state_file.exists() and preds_file.exists() and settings_file.exists():
             result_files.append(search_dir)
 
-    return sorted(result_files, reverse=True)  # Most recent first
+    def _key_func(path: Path):
+        return path.stat().st_mtime
+
+    return sorted(result_files, key=_key_func, reverse=True)  # Most recent first
 
 
 def load_and_prepare_results(file_path: Path, settings: dict, k_bin_width: int = 40) -> pd.DataFrame:
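The sorting change above replaces a reverse lexicographic sort of the result directories with a sort on filesystem modification time, so "most recent first" holds even when directory names do not sort chronologically. The pattern in isolation:

    from pathlib import Path

    def most_recent_first(paths: list[Path]) -> list[Path]:
        # st_mtime is seconds since the epoch; newest entries sort first.
        return sorted(paths, key=lambda p: p.stat().st_mtime, reverse=True)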