Run alphaearth as embeddings and add era5 download via CDS

This commit is contained in:
Tobias Hölzer 2025-09-29 18:45:57 +02:00
parent bd48637491
commit c0c3700be8
4 changed files with 2489 additions and 1783 deletions

View file

@ -7,13 +7,18 @@ import cyclopts
import ee import ee
import geemap import geemap
import geopandas as gpd import geopandas as gpd
import numpy as np
import pandas as pd
from rich import pretty, traceback from rich import pretty, traceback
from rich.progress import track
pretty.install() pretty.install()
traceback.install() traceback.install()
ee.Initialize(project="ee-tobias-hoelzer") ee.Initialize(project="ee-tobias-hoelzer")
DATA_DIR = Path("data") DATA_DIR = Path("data")
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
def cli(grid: Literal["hex", "healpix"], level: int, year: int): def cli(grid: Literal["hex", "healpix"], level: int, year: int):
@ -25,16 +30,16 @@ def cli(grid: Literal["hex", "healpix"], level: int, year: int):
year (int): The year to extract embeddings for. Must be between 2017 and 2024. year (int): The year to extract embeddings for. Must be between 2017 and 2024.
""" """
grid = gpd.read_parquet(DATA_DIR / f"grids/permafrost_{grid}{level}_grid.parquet") gridname = f"permafrost_{grid}{level}"
eegrid = ee.FeatureCollection(grid.to_crs("epsg:4326").__geo_interface__) grid = gpd.read_parquet(DATA_DIR / f"grids/{gridname}_grid.parquet")
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL").filterDate( embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
f"{year}-01-01", f"{year}-12-31" embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
) bands = [f"A{str(i).zfill(2)}" for i in range(64)]
def extract_embedding(feature): def extract_embedding(feature):
# Filter collection by geometry # Filter collection by geometry
geom = feature.geometry() geom = feature.geometry()
embedding = embedding_collection.filterBounds(geom).mosaic().clip(geom) embedding = embedding_collection.filterBounds(geom).mosaic()
# Get mean embedding value for the geometry # Get mean embedding value for the geometry
mean_dict = embedding.reduceRegion( mean_dict = embedding.reduceRegion(
reducer=ee.Reducer.median(), reducer=ee.Reducer.median(),
@ -43,11 +48,36 @@ def cli(grid: Literal["hex", "healpix"], level: int, year: int):
# Add mean embedding values as properties to the feature # Add mean embedding values as properties to the feature
return feature.set(mean_dict) return feature.set(mean_dict)
eeegrid = eegrid.map(extract_embedding) # Process grid in batches of 100
df = geemap.ee_to_df(eeegrid) batch_size = 100
bands = [f"A{str(i).zfill(2)}" for i in range(64)] all_results = []
n_batches = len(grid) // batch_size
for batch_num, batch_grid in track(
enumerate(np.array_split(grid, n_batches)),
description="Processing batches...",
total=n_batches,
):
print(f"Processing batch with {len(batch_grid)} items")
# Convert batch to EE FeatureCollection
eegrid_batch = ee.FeatureCollection(batch_grid.to_crs("epsg:4326").__geo_interface__)
# Apply embedding extraction to batch
eeegrid_batch = eegrid_batch.map(extract_embedding)
df_batch = geemap.ee_to_df(eeegrid_batch)
# Save batch immediately to disk as backup
batch_filename = f"{gridname}_embeddings-{year}_batch{batch_num:06d}.parquet"
batch_result = batch_grid.merge(df_batch[[*bands, "cell_id"]], on="cell_id", how="left")
batch_result.to_parquet(EMBEDDINGS_DIR / f"{batch_filename}")
# Store batch results
all_results.append(df_batch)
# Combine all batch results
df = pd.concat(all_results, ignore_index=True)
embeddings_on_grid = grid.merge(df[[*bands, "cell_id"]], on="cell_id", how="left") embeddings_on_grid = grid.merge(df[[*bands, "cell_id"]], on="cell_id", how="left")
embeddings_on_grid.to_parquet(DATA_DIR / f"embeddings/permafrost_{grid}{level}_embeddings-{year}.parquet") embeddings_on_grid.to_parquet(EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet")
if __name__ == "__main__": if __name__ == "__main__":

63
cds.py Normal file
View file

@ -0,0 +1,63 @@
"""Download ERA5 data from the Copernicus Data Store.
Web platform: https://cds.climate.copernicus.eu
"""
import re
from pathlib import Path
import cdsapi
import cyclopts
from rich import pretty, print, traceback
traceback.install()
pretty.install()
def hourly(years: str):
    """Download hourly ERA5 single-level data from the Copernicus Data Store.

    Submits one asynchronous CDS request per month in the given year range and
    downloads each monthly result as a NetCDF file into a fixed output
    directory. The spatial subset covers the panarctic band (50N to 85N).

    Args:
        years (str): Inclusive year range to download, separated by a '-',
            e.g. "2000-2010".

    Raises:
        ValueError: If ``years`` is not formatted as 'YYYY-YYYY', the range is
            outside 1950-2024, or the start year is after the end year.
    """
    # Validate with explicit raises: `assert` is stripped under `python -O`,
    # so it must never guard user input.
    if not re.fullmatch(r"\d{4}-\d{4}", years):
        raise ValueError("Years must be in the format 'YYYY-YYYY'")
    start_year, end_year = map(int, years.split("-"))
    if not 1950 <= start_year <= end_year <= 2024:
        raise ValueError("Years must be between 1950 and 2024")

    dataset = "reanalysis-era5-single-levels"
    # wait_until_complete=False lets the CDS queue requests asynchronously
    # instead of blocking on each one until it is processed.
    client = cdsapi.Client(wait_until_complete=False)
    outdir = Path("/isipd/projects/p_aicore_pf/tohoel001/era5-cds").resolve()
    outdir.mkdir(parents=True, exist_ok=True)
    print(f"Downloading ERA5 data from {start_year} to {end_year}...")
    for y in range(start_year, end_year + 1):
        for month in [f"{i:02d}" for i in range(1, 13)]:
            request = {
                "product_type": ["reanalysis"],
                "variable": [
                    "2m_temperature",
                    "total_precipitation",
                    "snow_depth",
                    "snow_density",
                    "snowfall",
                    "lake_ice_temperature",
                    "surface_sensible_heat_flux",
                ],
                "year": [str(y)],
                "month": [month],
                # Day 31 is requested for every month; the CDS API ignores
                # days that do not exist in a given month.
                "day": [f"{i:02d}" for i in range(1, 32)],
                "time": [f"{i:02d}:00" for i in range(0, 24)],
                "data_format": "netcdf",
                "download_format": "unarchived",
                # Bounding box as [North, West, South, East].
                "area": [85, -180, 50, 180],
            }
            outpath = outdir / f"era5_{y}_{month}.zip"
            client.retrieve(dataset, request).download(str(outpath))
            print(f"Downloaded {dataset} for {y}-{month}")
if __name__ == "__main__":
    # Expose `hourly` as a command-line entry point; cyclopts builds the CLI
    # (argument parsing, help text) from the function signature and docstring.
    cyclopts.run(hourly)

View file

@ -9,12 +9,16 @@ dependencies = [
"aiohttp>=3.12.11", "aiohttp>=3.12.11",
"bokeh>=3.7.3", "bokeh>=3.7.3",
"cartopy>=0.24.1", "cartopy>=0.24.1",
"cdsapi>=0.7.6",
"cyclopts>=3.17.0", "cyclopts>=3.17.0",
"dask>=2025.5.1", "dask>=2025.5.1",
"distributed>=2025.5.1", "distributed>=2025.5.1",
"entropyc", "earthengine-api>=1.6.9",
"eemont>=2025.7.1",
# "entropyc",
"flox>=0.10.4", "flox>=0.10.4",
"folium>=0.19.7", "folium>=0.19.7",
"geemap>=0.36.3",
"geopandas>=1.1.0", "geopandas>=1.1.0",
"h3>=4.2.2", "h3>=4.2.2",
"h5netcdf>=1.6.4", "h5netcdf>=1.6.4",
@ -22,6 +26,7 @@ dependencies = [
"ipywidgets>=8.1.7", "ipywidgets>=8.1.7",
"mapclassify>=2.9.0", "mapclassify>=2.9.0",
"matplotlib>=3.10.3", "matplotlib>=3.10.3",
"netcdf4>=1.7.2",
"numpy>=2.3.0", "numpy>=2.3.0",
"odc-geo[all]>=0.4.10", "odc-geo[all]>=0.4.10",
"pyarrow>=20.0.0", "pyarrow>=20.0.0",
@ -39,5 +44,5 @@ dependencies = [
] ]
[tool.uv.sources] [tool.uv.sources]
entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" } # entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
xanimate = { git = "https://github.com/davbyr/xAnimate" } xanimate = { git = "https://github.com/davbyr/xAnimate" }

4150
uv.lock generated

File diff suppressed because it is too large Load diff