Add arcticdem

This commit is contained in:
Tobias Hölzer 2025-11-22 18:56:34 +01:00
parent d5b35d6da4
commit 1a71883999
18 changed files with 6981 additions and 2068 deletions

View file

@ -1 +1 @@
3.12
3.13

8289
pixi.lock generated

File diff suppressed because it is too large Load diff

View file

@ -4,7 +4,7 @@ version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [{ name = "Tobias Hölzer", email = "tobiashoelzer@hotmail.com" }]
# requires-python = ">=3.10,<3.13"
requires-python = ">=3.13,<3.14"
dependencies = [
"aiohttp>=3.12.11",
"bokeh>=3.7.3",
@ -28,19 +28,19 @@ dependencies = [
"mapclassify>=2.9.0",
"matplotlib>=3.10.3",
"netcdf4>=1.7.2",
"numba>=0.62.1",
"numba>=0.61.2",
"numbagg>=0.9.3",
"numpy>=2.3.0",
"numpy>=2.2.6",
"odc-geo[all]>=0.4.10",
"opt-einsum>=3.4.0",
"pyarrow>=20.0.0",
"pyarrow>=18.1.0",
"rechunker>=0.5.2",
"requests>=2.32.3",
"rich>=14.0.0",
"rioxarray>=0.19.0",
"scipy>=1.15.3",
"seaborn>=0.13.2",
"smart-geocubes[gee,dask,stac,viz]>=0.0.9",
"smart-geocubes[stac]>=0.1.0",
"stopuhr>=0.0.10",
"ultraplot>=1.63.0",
"xanimate",
@ -51,7 +51,13 @@ dependencies = [
"geocube>=0.7.1,<0.8",
"streamlit>=1.50.0,<2",
"altair[all]>=5.5.0,<6",
"h5netcdf>=1.7.3,<2", "streamlit-folium>=0.25.3,<0.26", "st-theme>=1.2.3,<2",
"h5netcdf>=1.7.3,<2",
"streamlit-folium>=0.25.3,<0.26",
"st-theme>=1.2.3,<2",
"xgboost>=3.1.1,<4",
"s3fs>=2025.10.0,<2026",
"xarray-spatial",
"cupy-xarray>=0.1.4,<0.2",
]
[project.scripts]
@ -59,6 +65,7 @@ create-grid = "entropice.grids:main"
darts = "entropice.darts:main"
alpha-earth = "entropice.alphaearth:main"
era5 = "entropice.era5:cli"
arcticdem = "entropice.arcticdem:cli"
train = "entropice.training:main"
dataset = "entropice.dataset:main"
@ -69,10 +76,20 @@ build-backend = "hatchling.build"
[tool.uv]
package = true
[[tool.uv.index]]
name = "nvidia"
url = "https://pypi.nvidia.com"
explicit = true
[tool.uv.sources]
entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
entropy = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" }
xanimate = { git = "https://github.com/davbyr/xAnimate" }
xdem = { git = "https://github.com/GlacioHack/xdem" }
xarray-spatial = { git = "https://github.com/relativityhd/xarray-spatial" }
cudf-cu12 = { index = "nvidia" }
cuml-cu12 = { index = "nvidia" }
cuspatial-cu12 = { index = "nvidia" }
[tool.ruff.lint.pyflakes]
# Ignore libraries when checking for unused imports
@ -87,7 +104,8 @@ allowed-unused-imports = [
]
[tool.pixi.workspace]
channels = ["conda-forge"]
channels = ["nvidia", "conda-forge", "rapidsai"]
channel-priority = "disabled"
platforms = ["linux-64"]
[tool.pixi.activation.env]
@ -116,4 +134,5 @@ cupy = ">=13.6.0,<14"
nccl = ">=2.27.7.1,<3"
cudnn = ">=9.13.1.26,<10"
cusparselt = ">=0.8.1.1,<0.9"
cuda-version = "12.1.*"
cuda-version = "12.9.*"
rapids = ">=25.10.0,<26"

11
scripts/00grids.sh Normal file
View file

@ -0,0 +1,11 @@
#! /bin/bash
# Create all hex and healpix grids used by the project.
for level in 3 4 5 6; do
    pixi run create-grid --grid hex --level "$level"
done
for level in 6 7 8 9 10; do
    pixi run create-grid --grid healpix --level "$level"
done

11
scripts/01darts.sh Normal file
View file

@ -0,0 +1,11 @@
#! /bin/bash
# Run the DARTS extraction for every grid/level combination.
for level in 3 4 5 6; do
    pixi run darts --grid hex --level "$level"
done
for level in 6 7 8 9 10; do
    pixi run darts --grid healpix --level "$level"
done

23
scripts/02alphaearth.sh Normal file
View file

@ -0,0 +1,23 @@
#!/bin/bash
# Download AlphaEarth embeddings per grid, then merge each grid into a zarr store.
for level in 3 4 5; do
    pixi run alpha-earth download --grid hex --level "$level"
done
for level in 6 7 8 9; do
    pixi run alpha-earth download --grid healpix --level "$level"
done
for level in 3 4 5; do
    pixi run alpha-earth combine-to-zarr --grid hex --level "$level"
done
for level in 6 7 8 9; do
    pixi run alpha-earth combine-to-zarr --grid healpix --level "$level"
done
# hex level 6 and healpix level 10 run last (original execution order kept).
pixi run alpha-earth download --grid hex --level 6
pixi run alpha-earth download --grid healpix --level 10
pixi run alpha-earth combine-to-zarr --grid hex --level 6
pixi run alpha-earth combine-to-zarr --grid healpix --level 10

15
scripts/03era5.sh Normal file
View file

@ -0,0 +1,15 @@
#!/bin/bash
# Prerequisites (run once beforehand):
#   pixi run era5 download
#   pixi run era5 enrich
for level in 3 4 5 6; do
    pixi run era5 spatial-agg --grid hex --level "$level"
done
for level in 6 7 8 9 10; do
    pixi run era5 spatial-agg --grid healpix --level "$level"
done

View file

@ -1,17 +0,0 @@
#!/bin/bash
# Downloads were completed in a previous run:
#   uv run alpha-earth download --grid hex --level {3,4,5}
#   uv run alpha-earth download --grid healpix --level {6,7,8,9}
for level in 3 4 5; do
    uv run alpha-earth combine-to-zarr --grid hex --level "$level"
done
for level in 6 7 8 9; do
    uv run alpha-earth combine-to-zarr --grid healpix --level "$level"
done

View file

@ -1,66 +0,0 @@
"""Download ERA5 data from the Copernicus Data Store.
Web platform: https://cds.climate.copernicus.eu
"""
import os
import re
from pathlib import Path
import cdsapi
import cyclopts
from rich import pretty, print, traceback
traceback.install()
pretty.install()
DATA_DIR = Path(os.environ.get("DATA_DIR", "../../data")) / "entropyc-rts"
def hourly(years: str):
    """Download hourly ERA5 single-level data from the Copernicus Data Store.

    Requests one zip file per month for the pan-Arctic bounding box (50N-85N)
    and saves it under ``DATA_DIR / "era5/cds"``.

    Args:
        years (str): Years to download, separated by a '-' (e.g. "2000-2010").

    Raises:
        ValueError: If ``years`` is not in the 'YYYY-YYYY' format, or the range
            is not within 1950-2024 (inclusive) with start <= end.
    """
    # Validate with ValueError instead of assert: asserts are stripped under `python -O`.
    if not re.fullmatch(r"\d{4}-\d{4}", years):
        raise ValueError("Years must be in the format 'YYYY-YYYY'")
    start_year, end_year = map(int, years.split("-"))
    if not (1950 <= start_year <= end_year <= 2024):
        raise ValueError("Years must be between 1950 and 2024")
    dataset = "reanalysis-era5-single-levels"
    client = cdsapi.Client(wait_until_complete=False)
    outdir = (DATA_DIR / "era5/cds").resolve()
    outdir.mkdir(parents=True, exist_ok=True)
    print(f"Downloading ERA5 data from {start_year} to {end_year}...")
    for y in range(start_year, end_year + 1):
        for month in [f"{i:02d}" for i in range(1, 13)]:
            request = {
                "product_type": ["reanalysis"],
                "variable": [
                    "2m_temperature",
                    "total_precipitation",
                    "snow_depth",
                    "snow_density",
                    "snowfall",
                    "lake_ice_temperature",
                    "surface_sensible_heat_flux",
                ],
                "year": [str(y)],
                "month": [month],
                # Days not present in a month (e.g. Feb 30) are silently skipped by CDS.
                "day": [f"{i:02d}" for i in range(1, 32)],
                "time": [f"{i:02d}:00" for i in range(0, 24)],
                "data_format": "netcdf",
                "download_format": "unarchived",
                # North, West, South, East bounding box.
                "area": [85, -180, 50, 180],
            }
            outpath = outdir / f"era5_{y}_{month}.zip"
            client.retrieve(dataset, request).download(str(outpath))
            print(f"Downloaded {dataset} for {y}-{month}")
if __name__ == "__main__":
    # Expose `hourly` as a CLI when the module is executed directly.
    cyclopts.run(hourly)

View file

@ -1,18 +0,0 @@
#!/bin/bash
# Run ERA5 spatial aggregation for one aggregation type across all grids.
# Usage: <script> <agg>  where <agg> is summer, winter or yearly.
# Fixes: validate the argument up-front and quote "$agg" so an empty or
# whitespace-containing argument cannot silently change the command line.
agg=$1
case "$agg" in
    summer|winter|yearly) ;;
    *)
        echo "Usage: $0 {summer|winter|yearly}" >&2
        exit 1
        ;;
esac
echo "Running ERA5 spatial aggregation for aggregation type: $agg"
for level in 3 4 5; do
    uv run era5 spatial-agg --grid hex --level "$level" --agg "$agg"
done
for level in 6 7 8 9; do
    uv run era5 spatial-agg --grid healpix --level "$level" --agg "$agg"
done

View file

@ -1,9 +0,0 @@
#!/bin/bash
# Run RTS processing across all grid/level combinations.
for level in 3 4 5; do
    uv run rts --grid hex --level "$level"
done
for level in 6 7 8 9; do
    uv run rts --grid healpix --level "$level"
done

View file

@ -30,6 +30,8 @@ ee.Initialize(project="ee-tobias-hoelzer")
cli = cyclopts.App(name="alpha-earth")
# 7454521782,230147807,10000000.
@cli.command()
def download(grid: Literal["hex", "healpix"], level: int):
@ -60,6 +62,8 @@ def download(grid: Literal["hex", "healpix"], level: int):
.combine(ee.Reducer.mean(), sharedInputs=True)
.combine(ee.Reducer.percentile([1, 5, 25, 75, 95, 99]), sharedInputs=True),
geometry=geom,
scale=10,
bestEffort=True,
)
# Add mean embedding values as properties to the feature
return feature.set(mean_dict)
@ -85,7 +89,7 @@ def download(grid: Literal["hex", "healpix"], level: int):
# Combine all batch results
df = pd.concat(all_results, ignore_index=True)
embeddings_on_grid = grid.merge(df[[*bands, "cell_id"]], on="cell_id", how="left")
embeddings_on_grid = grid_gdf.merge(df[[*bands, "cell_id"]], on="cell_id", how="left")
embeddings_file = get_annual_embeddings_file(grid, level, year)
embeddings_on_grid.to_parquet(embeddings_file)
print(f"Saved embeddings for year {year} to {embeddings_file}.")

308
src/entropice/arcticdem.py Normal file
View file

@ -0,0 +1,308 @@
import datetime
from dataclasses import dataclass
from math import ceil
import cupy as cp
import cupy_xarray
import cupyx.scipy.signal
import cyclopts
import dask.array
import dask.distributed as dd
import icechunk.xarray
import numpy as np
import smart_geocubes
import xarray as xr
import xrspatial
import zarr
from cupyx.scipy.ndimage import binary_dilation, binary_erosion, distance_transform_edt
from rich import pretty, print, traceback
from stopuhr import stopwatch
from xrspatial.aspect import _run_cupy as aspect_cupy
from xrspatial.curvature import _run_cupy as curvature_cupy
from xrspatial.slope import _run_cupy as slope_cupy
from zarr.codecs import BloscCodec
from entropice.paths import arcticdem_store
# Rich console setup: pretty tracebacks (with locals, hiding cyclopts frames)
# and pretty-printed objects.
traceback.install(show_locals=True, suppress=[cyclopts])
pretty.install()
# Raise zarr's async request concurrency to speed up many-chunk reads/writes.
zarr.config.set({"async.concurrency": 128})
# CLI entry point; subcommands are registered below via @cli.command().
cli = cyclopts.App(name="arcticdem")
@cli.command()
def download():
    """Download the full ArcticDEM 32m mosaic into the local icechunk store."""
    accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
    # Idempotent: reuse the store if it already exists.
    accessor.create(exists_ok=True)
    with stopwatch("Download ArcticDEM data"):
        accessor.procedural_download(accessor.extent.extent, None)
@dataclass
class KernelFactory:
    """Build convolution kernels (as float32 CuPy arrays) for a raster resolution.

    Attributes are the raster cell size in metres (`res`) and the kernel edge
    length in pixels (`size_px`).
    """

    res: int  # cell size in metres
    size_px: int  # kernel edge length in pixels

    @property
    def size(self):
        """Kernel extent in metres."""
        return self.size_px * self.res

    @property
    def inner(self):
        """Inner radius (metres) of the annulus: one third of the extent, rounded up."""
        return ceil(self.size / 3)

    @staticmethod
    def _to_cupy_f32(kernel):
        """Move a NumPy kernel to the GPU as float32."""
        return cp.asarray(kernel.astype("float32"))

    def ring(self):
        """Annulus (ring) kernel, used for TPI neighborhood means."""
        annulus = xrspatial.convolution.annulus_kernel(self.res, self.res, self.size, self.inner)
        return self._to_cupy_f32(annulus)

    def tri(self):
        """Square kernel with a zeroed centre cell, used for TRI focal sums."""
        square = np.ones((self.size_px, self.size_px), dtype=float)
        centre = self.size_px // 2
        square[centre, centre] = 0  # exclude the centre cell itself
        return self._to_cupy_f32(xrspatial.convolution.custom_kernel(square))

    def vrm(self):
        """Normalized box kernel (weights sum to 1), used for VRM component means."""
        box = np.full((self.size_px, self.size_px), 1.0 / self.size_px**2)
        return self._to_cupy_f32(xrspatial.convolution.custom_kernel(box))
def tpi_cupy(chunk, kernels: KernelFactory):
    """Topographic Position Index: elevation minus the annulus-neighborhood mean."""
    ring = kernels.ring()
    ring = ring / cp.nansum(ring)  # normalize so the convolution is a mean
    neighborhood_mean = cupyx.scipy.signal.convolve2d(chunk, ring, mode="same")
    return chunk - neighborhood_mean
def tri_cupy(chunk, kernels: KernelFactory):
    """Terrain Ruggedness Index via convolution.

    Uses the expansion sum_i (c - n_i)^2 = N*c^2 - 2*c*sum(n) + sum(n^2) over
    the N = kernel.size - 1 neighbor cells (the centre cell is zeroed in the
    kernel), so two convolutions replace an explicit neighbor loop.
    """
    kernel = kernels.tri()
    c2 = chunk**2
    focal_sum = cupyx.scipy.signal.convolve2d(chunk, kernel, mode="same")
    focal_sum_squared = cupyx.scipy.signal.convolve2d(c2, kernel, mode="same")
    # cp.sqrt instead of np.sqrt: same result (NumPy dispatches to CuPy via
    # __array_ufunc__), but explicit about staying on the GPU.
    # NOTE(review): float rounding can make the radicand slightly negative -> NaN;
    # consider clamping with cp.maximum(..., 0) if NaN speckles appear.
    tri = cp.sqrt((kernel.size - 1) * c2 - 2 * chunk * focal_sum + focal_sum_squared)
    return tri
def ruggedness_cupy(chunk, slope, aspect, kernels: KernelFactory):
    """Vector Ruggedness Measure (VRM) on the GPU.

    Decomposes each cell's unit surface normal (from slope/aspect, both in
    degrees) into x/y/z components, averages them over the kernel neighborhood,
    and returns 1 - |resultant vector|: ~0 for flat/uniform terrain, towards 1
    for rugged terrain.

    Fix: the flat-cell sentinel (aspect == -1, as emitted by xrspatial) was
    previously checked AFTER converting to radians (`aspect_rad == -1`), so it
    could never match; the check now runs on the degree values.
    """
    # Treat flat cells (sentinel -1) as aspect 0 degrees before converting.
    aspect = cp.where(aspect == -1, 0, aspect)
    slope_rad = slope * (cp.pi / 180)
    aspect_rad = aspect * (cp.pi / 180)
    xy = cp.sin(slope_rad)
    z = cp.cos(slope_rad)
    x = xy * cp.sin(aspect_rad)
    y = xy * cp.cos(aspect_rad)
    kernel = kernels.vrm()
    # Mean of x, y, and z components in the neighborhood (kernel is pre-normalized).
    x_sum = cupyx.scipy.signal.convolve2d(x, kernel, mode="same")
    y_sum = cupyx.scipy.signal.convolve2d(y, kernel, mode="same")
    z_sum = cupyx.scipy.signal.convolve2d(z, kernel, mode="same")
    # Resultant vector magnitude; cp.sqrt for explicit GPU execution (np.sqrt
    # would dispatch to CuPy anyway).
    vrm = 1 - cp.sqrt(x_sum**2 + y_sum**2 + z_sum**2)
    return vrm
def _get_xy_chunk(
    chunk: np.ndarray,
    x: np.ndarray,
    y: np.ndarray,
    block_info=None,
    depth: int = 15,
    chunk_size: int = 3600,
) -> tuple[cp.ndarray, cp.ndarray]:
    """Build the (xx, yy) coordinate grids matching one overlapped dask chunk.

    Generalized: the overlap depth (15) and chunk edge length (3600) were
    hard-coded; they are now keyword parameters (defaults unchanged) so the
    function can follow the chunking configured in `_enrich`.

    Args:
        chunk: Overlapped DEM chunk (only its shape is used).
        x: Full x coordinate vector of the whole DEM.
        y: Full y coordinate vector of the whole DEM.
        block_info: dask map_overlap metadata; `block_info[0]["chunk-location"]`
            is this chunk's (row, col) index in the chunk grid.
        depth: Overlap halo in pixels (must match `depth=` in map_overlap).
        chunk_size: Un-overlapped chunk edge length in pixels.

    Returns:
        tuple[cp.ndarray, cp.ndarray]: (xx, yy) of shape `chunk.shape`, where
            xx[i, j] / yy[i, j] is the x / y coordinate of pixel (i, j).
    """
    chunk_loc = block_info[0]["chunk-location"]
    # Calculate safe slice bounds for edge chunks
    y_start = max(0, chunk_size * chunk_loc[0] - depth)
    y_end = min(len(y), chunk_size * chunk_loc[0] + chunk_size + depth)
    x_start = max(0, chunk_size * chunk_loc[1] - depth)
    x_end = min(len(x), chunk_size * chunk_loc[1] + chunk_size + depth)
    # Extract coordinate arrays with safe bounds
    y_chunk = cp.asarray(y[y_start:y_end])
    x_chunk = cp.asarray(x[x_start:x_end])
    # Handle cases where the extracted chunk doesn't match expected dimensions
    if len(y_chunk) != chunk.shape[0] or len(x_chunk) != chunk.shape[1]:
        print(
            f"Adjusting coordinate chunk sizes: y_chunk {len(y_chunk)} vs chunk {chunk.shape[0]}, "
            f"x_chunk {len(x_chunk)} vs chunk {chunk.shape[1]}"
        )
        # Pad coordinates (with edge values) to match chunk dimensions if needed
        if len(y_chunk) < chunk.shape[0]:
            pad_start = chunk.shape[0] - len(y_chunk)
            y_chunk = cp.pad(y_chunk, (pad_start, 0), mode="edge")
        if len(x_chunk) < chunk.shape[1]:
            pad_start = chunk.shape[1] - len(x_chunk)
            x_chunk = cp.pad(x_chunk, (pad_start, 0), mode="edge")
    # Broadcast the 1D coordinate vectors to full 2D grids.
    yy = cp.repeat(y_chunk.reshape(-1, 1), x_chunk.shape[0], axis=1)
    xx = cp.repeat(x_chunk.reshape(1, -1), y_chunk.shape[0], axis=0)
    return xx, yy
def _enrich_chunk(chunk: np.ndarray, x: np.ndarray, y: np.ndarray, block_info=None) -> np.ndarray:
    """Compute 12 terrain features for one overlapped DEM chunk on the GPU.

    Designed for `dask.array.map_overlap` with ``depth=15`` and ``trim=False``:
    the input carries a 15-pixel halo, which is trimmed from the output here.

    Fixes vs. original: the local name `res` (raster resolution) was later
    reassigned to the result array, shadowing the resolution; the result now
    uses its own name. Parameter annotations corrected from `np.array` (a
    function) to `np.ndarray`.

    Args:
        chunk: 2D DEM chunk including the overlap halo.
        x: Full x coordinate vector of the whole DEM.
        y: Full y coordinate vector of the whole DEM.
        block_info: Chunk metadata injected by dask (used to locate the chunk).

    Returns:
        np.ndarray: float32 array of shape (12, H - 30, W - 30); feature order
            matches the `features` list in `_enrich`.
    """
    res = 32  # 32m resolution
    small_kernels = KernelFactory(res=res, size_px=3)  # ~3x3 kernels (96m)
    medium_kernels = KernelFactory(res=res, size_px=7)  # ~7x7 kernels (224m)
    large_kernels = KernelFactory(res=res, size_px=15)  # ~15x15 kernels (480m)
    # Check if there is data in the chunk
    if np.all(np.isnan(chunk)):
        # Return an array of NaNs with the expected (halo-trimmed) shape
        return np.full(
            (12, chunk.shape[0] - 2 * large_kernels.size_px, chunk.shape[1] - 2 * large_kernels.size_px),
            np.nan,
            dtype=np.float32,
        )
    chunk = cp.asarray(chunk)
    # Interpolate missing values in chunk (for patches smaller than 7x7 pixels):
    # erosion+dilation keeps only large NaN patches, so `mask` selects the small holes.
    mask = cp.isnan(chunk)
    mask &= ~binary_dilation(binary_erosion(mask, iterations=3, brute_force=True), iterations=3, brute_force=True)
    if cp.any(mask):
        # Nearest-valid-neighbor fill: indices of the closest non-hole pixel per cell.
        indices = distance_transform_edt(mask, return_distances=False, return_indices=True)
        chunk = chunk[tuple(indices)]
    # TPI at three neighborhood scales
    tpi_small = tpi_cupy(chunk, small_kernels)
    tpi_medium = tpi_cupy(chunk, medium_kernels)
    tpi_large = tpi_cupy(chunk, large_kernels)
    # Slope
    slope = slope_cupy(chunk, res, res)
    # Aspect, rotated from grid orientation to geographic orientation.
    aspect = aspect_cupy(chunk)
    xx, yy = _get_xy_chunk(chunk, x, y, block_info)
    # NOTE(review): assumes a polar projection where north at (x, y) points along
    # atan2(y, x) + 90 degrees — confirm against the DEM's CRS.
    aspect_correction = cp.arctan2(yy, xx) * (180 / cp.pi) + 90
    aspect = (aspect + aspect_correction) % 360
    # Curvature
    curvature = curvature_cupy(chunk, res)
    # TRI at three neighborhood scales
    tri_small = tri_cupy(chunk, small_kernels)
    tri_medium = tri_cupy(chunk, medium_kernels)
    tri_large = tri_cupy(chunk, large_kernels)
    # Vector ruggedness measure at three neighborhood scales
    vrm_small = ruggedness_cupy(chunk, slope, aspect, small_kernels)
    vrm_medium = ruggedness_cupy(chunk, slope, aspect, medium_kernels)
    vrm_large = ruggedness_cupy(chunk, slope, aspect, large_kernels)
    # Stack on GPU, then move to CPU once
    stacked = cp.stack(
        [
            tpi_small,
            tpi_medium,
            tpi_large,
            slope,
            aspect,
            curvature,
            tri_small,
            tri_medium,
            tri_large,
            vrm_small,
            vrm_medium,
            vrm_large,
        ],
        axis=0,
    )
    # Trim the overlap halo on GPU
    halo = large_kernels.size_px
    stacked = stacked[:, halo:-halo, halo:-halo]
    # Single GPU -> CPU transfer at the end
    return cp.asnumpy(stacked)
def _enrich(terrain: xr.DataArray):
    """Lazily compute the 12 terrain features over the whole DEM with dask.

    Maps `_enrich_chunk` over the DEM with a 15-pixel overlap so the focal
    convolutions are valid at chunk borders. Dead commented-out code from an
    earlier coordinate-passing experiment has been removed.

    Args:
        terrain (xr.DataArray): 2D (y, x) DEM backed by a chunked dask array.

    Returns:
        tuple[xr.DataArray, list[str]]: A lazy (feature, y, x) DataArray and
            the feature names (order matches the `feature` coordinate).
    """
    features = [
        "tpi_small",
        "tpi_medium",
        "tpi_large",
        "slope",
        "aspect",
        "curvature",
        "tri_small",
        "tri_medium",
        "tri_large",
        "vrm_small",
        "vrm_medium",
        "vrm_large",
    ]
    # New leading axis for the stacked features; spatial chunking is preserved.
    output_chunks = (len(features), *tuple(terrain.data.chunks))
    enriched_da = dask.array.map_overlap(
        _enrich_chunk,
        terrain.data,
        x=terrain.x.to_numpy(),
        y=terrain.y.to_numpy(),
        depth=15,  # large_kernels.size_px; must match the trim inside _enrich_chunk
        chunks=output_chunks,
        new_axis=0,
        dtype=np.float32,
        meta=np.array((), dtype=np.float32),
        trim=False,  # _enrich_chunk trims the halo itself
        boundary=np.nan,
        token="enrich_arcticdem",
    )
    enriched_da = xr.DataArray(
        enriched_da,
        dims=("feature", "y", "x"),
        coords={
            "feature": features,
            "y": terrain.y,
            "x": terrain.x,
        },
    )
    return enriched_da, features
@cli.command()
def enrich():
    """Derive terrain features from the ArcticDEM store and write them back.

    Starts a local dask cluster, lazily computes the 12 terrain features via
    `_enrich`, and appends them to the icechunk store as new variables in a
    single commit.
    """
    with (
        dd.LocalCluster(n_workers=7, threads_per_worker=2, memory_limit="30GB") as cluster,
        dd.Client(cluster) as client,
    ):
        print(client)
        print(client.dashboard_link)
        accessor = smart_geocubes.ArcticDEM32m(arcticdem_store)
        # Garbage collect from previous runs
        accessor.repo.garbage_collect(datetime.datetime.now(datetime.UTC))
        adem = accessor.open_xarray()
        # NOTE(review): _FillValue is dropped from the coordinate attrs —
        # presumably to avoid encoding conflicts when writing back; confirm.
        del adem.y.attrs["_FillValue"]
        del adem.x.attrs["_FillValue"]
        enriched, new_features = _enrich(adem.dem)
        for feature in new_features:
            adem[feature] = enriched.sel(feature=feature)
        print(adem[new_features])
        encodings = {feature: {"compressors": [BloscCodec(clevel=5)]} for feature in new_features}
        session = accessor.repo.writable_session("main")
        # mode="a": append the new feature variables, keep existing ones.
        icechunk.xarray.to_icechunk(
            adem[new_features],
            session,
            mode="a",
            encoding=encodings,
        )
        session.commit("Add terrain features: TPI, slope, aspect, curvature, TRI, ruggedness")
        print("Enrichment complete.")
if __name__ == "__main__":
    # Script entry point: dispatch to the cyclopts CLI.
    cli()

View file

@ -71,6 +71,10 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
darts_counts = grid_gdf[darts_counts_columns]
grid_gdf["darts_rts_count"] = darts_counts.dropna(axis=0, how="all").sum(axis=1)
darts_density_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_density")]
darts_density = grid_gdf[darts_density_columns]
grid_gdf["darts_rts_density"] = darts_density.dropna(axis=0, how="all").max(axis=1)
output_path = get_darts_rts_file(grid, level)
grid_gdf.to_parquet(output_path)
print(f"Saved RTS labels to {output_path}")

View file

@ -4,6 +4,7 @@ Author: Tobias Hölzer
Date: 09. June 2025
"""
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Literal
import cartopy.crs as ccrs
@ -16,8 +17,9 @@ import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import xdggs
import xvec # noqa: F401
import xvec
from rich import pretty, print, traceback
from rich.progress import track
from shapely.geometry import Polygon
from shapely.ops import transform
from stopuhr import stopwatch
@ -64,6 +66,24 @@ def get_cell_ids(grid: Literal["hex", "healpix"], level: int):
return cell_ids
def _get_cell_polygon(hex0_cell, resolution: int) -> tuple[list[Polygon], list[str], list[float]]:
    """Expand one resolution-0 H3 cell into child polygons, ids and areas.

    Returns three parallel lists: shapely polygons in (lon, lat) order, the
    H3 cell ids, and the cell areas in km^2.
    """
    polygons: list[Polygon] = []
    cell_ids: list[str] = []
    areas: list[float] = []
    for child in h3.cell_to_children(hex0_cell, resolution):
        boundary = Polygon(h3.cell_to_boundary(child))
        # H3 boundaries come as (lat, lon); swap to (lon, lat) for shapely/GIS.
        polygons.append(transform(lambda x, y: (y, x), boundary))
        cell_ids.append(child)
        areas.append(h3.cell_area(child, unit="km^2"))
    return polygons, cell_ids, areas
@stopwatch("Create a global hex grid")
def create_global_hex_grid(resolution):
"""Create a global hexagonal grid using H3.
@ -76,23 +96,37 @@ def create_global_hex_grid(resolution):
"""
# Generate hexagons
# For resolutions >=5, use multiprocessing to speed up the process
hex0_cells = h3.get_res0_cells()
if resolution > 0:
hex_cells = []
for hex0_cell in hex0_cells:
hex_cells.extend(h3.cell_to_children(hex0_cell, resolution))
else:
hex_cells = hex0_cells
# Initialize lists to store hex information
hex_list = []
hex_id_list = []
hex_area_list = []
if resolution >= 5:
with ProcessPoolExecutor(max_workers=20) as executor:
future_to_hex = {
executor.submit(_get_cell_polygon, hex0_cell, resolution): hex0_cell for hex0_cell in hex0_cells
}
for future in track(
as_completed(future_to_hex), description="Creating hex polygons...", total=len(hex0_cells)
):
hex_batch, hex_id_batch, hex_area_batch = future.result()
hex_list.extend(hex_batch)
hex_id_list.extend(hex_id_batch)
hex_area_list.extend(hex_area_batch)
else:
if resolution > 0:
hex_cells = []
for hex0_cell in track(hex0_cells, description="Generating hex ids...", total=len(hex0_cells)):
hex_cells.extend(h3.cell_to_children(hex0_cell, resolution))
else:
hex_cells = hex0_cells
# Convert each hex ID to a polygon
for hex_id in hex_cells:
for hex_id in track(hex_cells, description="Creating hex polygons...", total=len(hex_cells)):
boundary_coords = h3.cell_to_boundary(hex_id)
hex_polygon = Polygon(boundary_coords)
hex_polygon = transform(lambda x, y: (y, x), hex_polygon) # Convert from (lat, lon) to (lon, lat)

View file

@ -36,7 +36,7 @@ def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifi
print(f"Predicting probabilities for {len(data)} cells...")
# Predict in batches to avoid memory issues
batch_size = 100_000
batch_size = 10_000
preds = []
for i in range(0, len(data), batch_size):
batch = data.iloc[i : i + batch_size]
@ -45,7 +45,7 @@ def predict_proba(grid: Literal["hex", "healpix"], level: int, clf: ESPAClassifi
X_batch = batch.drop(columns=cols_to_drop).dropna()
cell_ids = batch.loc[X_batch.index, "cell_id"].to_numpy()
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
X_batch = X_batch.to_numpy(dtype="float32")
X_batch = X_batch.to_numpy(dtype="float64")
X_batch = torch.asarray(X_batch, device=0)
batch_preds = clf.predict(X_batch).cpu().numpy()
batch_preds = gpd.GeoDataFrame(

View file

@ -27,6 +27,8 @@ WATERMASK_DIR.mkdir(parents=True, exist_ok=True)
TRAINING_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
arcticdem_store = DATA_DIR / "arcticdem32m.icechunk.zarr"
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
dartsl2_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"

View file

@ -111,7 +111,10 @@ def get_available_result_files() -> list[Path]:
if result_file.exists() and state_file.exists() and preds_file.exists() and settings_file.exists():
result_files.append(search_dir)
return sorted(result_files, reverse=True) # Most recent first
def _key_func(path: Path):
return path.stat().st_mtime
return sorted(result_files, key=_key_func, reverse=True) # Most recent first
def load_and_prepare_results(file_path: Path, settings: dict, k_bin_width: int = 40) -> pd.DataFrame: