Run alphaearth as embeddings and add era5 download via CDS
This commit is contained in:
parent
bd48637491
commit
c0c3700be8
4 changed files with 2489 additions and 1783 deletions
|
|
@ -7,13 +7,18 @@ import cyclopts
|
||||||
import ee
|
import ee
|
||||||
import geemap
|
import geemap
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
from rich import pretty, traceback
|
from rich import pretty, traceback
|
||||||
|
from rich.progress import track
|
||||||
|
|
||||||
pretty.install()
|
pretty.install()
|
||||||
traceback.install()
|
traceback.install()
|
||||||
ee.Initialize(project="ee-tobias-hoelzer")
|
ee.Initialize(project="ee-tobias-hoelzer")
|
||||||
|
|
||||||
DATA_DIR = Path("data")
|
DATA_DIR = Path("data")
|
||||||
|
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
|
||||||
|
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
def cli(grid: Literal["hex", "healpix"], level: int, year: int):
|
def cli(grid: Literal["hex", "healpix"], level: int, year: int):
|
||||||
|
|
@ -25,16 +30,16 @@ def cli(grid: Literal["hex", "healpix"], level: int, year: int):
|
||||||
year (int): The year to extract embeddings for. Must be between 2017 and 2024.
|
year (int): The year to extract embeddings for. Must be between 2017 and 2024.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
grid = gpd.read_parquet(DATA_DIR / f"grids/permafrost_{grid}{level}_grid.parquet")
|
gridname = f"permafrost_{grid}{level}"
|
||||||
eegrid = ee.FeatureCollection(grid.to_crs("epsg:4326").__geo_interface__)
|
grid = gpd.read_parquet(DATA_DIR / f"grids/{gridname}_grid.parquet")
|
||||||
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL").filterDate(
|
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
|
||||||
f"{year}-01-01", f"{year}-12-31"
|
embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
|
||||||
)
|
bands = [f"A{str(i).zfill(2)}" for i in range(64)]
|
||||||
|
|
||||||
def extract_embedding(feature):
|
def extract_embedding(feature):
|
||||||
# Filter collection by geometry
|
# Filter collection by geometry
|
||||||
geom = feature.geometry()
|
geom = feature.geometry()
|
||||||
embedding = embedding_collection.filterBounds(geom).mosaic().clip(geom)
|
embedding = embedding_collection.filterBounds(geom).mosaic()
|
||||||
# Get mean embedding value for the geometry
|
# Get mean embedding value for the geometry
|
||||||
mean_dict = embedding.reduceRegion(
|
mean_dict = embedding.reduceRegion(
|
||||||
reducer=ee.Reducer.median(),
|
reducer=ee.Reducer.median(),
|
||||||
|
|
@ -43,11 +48,36 @@ def cli(grid: Literal["hex", "healpix"], level: int, year: int):
|
||||||
# Add mean embedding values as properties to the feature
|
# Add mean embedding values as properties to the feature
|
||||||
return feature.set(mean_dict)
|
return feature.set(mean_dict)
|
||||||
|
|
||||||
eeegrid = eegrid.map(extract_embedding)
|
# Process grid in batches of 100
|
||||||
df = geemap.ee_to_df(eeegrid)
|
batch_size = 100
|
||||||
bands = [f"A{str(i).zfill(2)}" for i in range(64)]
|
all_results = []
|
||||||
|
n_batches = len(grid) // batch_size
|
||||||
|
for batch_num, batch_grid in track(
|
||||||
|
enumerate(np.array_split(grid, n_batches)),
|
||||||
|
description="Processing batches...",
|
||||||
|
total=n_batches,
|
||||||
|
):
|
||||||
|
print(f"Processing batch with {len(batch_grid)} items")
|
||||||
|
|
||||||
|
# Convert batch to EE FeatureCollection
|
||||||
|
eegrid_batch = ee.FeatureCollection(batch_grid.to_crs("epsg:4326").__geo_interface__)
|
||||||
|
|
||||||
|
# Apply embedding extraction to batch
|
||||||
|
eeegrid_batch = eegrid_batch.map(extract_embedding)
|
||||||
|
df_batch = geemap.ee_to_df(eeegrid_batch)
|
||||||
|
|
||||||
|
# Save batch immediately to disk as backup
|
||||||
|
batch_filename = f"{gridname}_embeddings-{year}_batch{batch_num:06d}.parquet"
|
||||||
|
batch_result = batch_grid.merge(df_batch[[*bands, "cell_id"]], on="cell_id", how="left")
|
||||||
|
batch_result.to_parquet(EMBEDDINGS_DIR / f"{batch_filename}")
|
||||||
|
|
||||||
|
# Store batch results
|
||||||
|
all_results.append(df_batch)
|
||||||
|
|
||||||
|
# Combine all batch results
|
||||||
|
df = pd.concat(all_results, ignore_index=True)
|
||||||
embeddings_on_grid = grid.merge(df[[*bands, "cell_id"]], on="cell_id", how="left")
|
embeddings_on_grid = grid.merge(df[[*bands, "cell_id"]], on="cell_id", how="left")
|
||||||
embeddings_on_grid.to_parquet(DATA_DIR / f"embeddings/permafrost_{grid}{level}_embeddings-{year}.parquet")
|
embeddings_on_grid.to_parquet(EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
63
cds.py
Normal file
63
cds.py
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
"""Download ERA5 data from the Copernicus Data Store.
|
||||||
|
|
||||||
|
Web platform: https://cds.climate.copernicus.eu
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cdsapi
|
||||||
|
import cyclopts
|
||||||
|
from rich import pretty, print, traceback
|
||||||
|
|
||||||
|
traceback.install()
|
||||||
|
pretty.install()
|
||||||
|
|
||||||
|
|
||||||
|
def hourly(years: str):
|
||||||
|
"""Download ERA5 data from the Copernicus Data Store.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
years (str): Years to download, seperated by a '-'.
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert re.compile(r"^\d{4}-\d{4}$").match(years), "Years must be in the format 'YYYY-YYYY'"
|
||||||
|
start_year, end_year = map(int, years.split("-"))
|
||||||
|
assert 1950 <= start_year <= end_year <= 2024, "Years must be between 1950 and 2024"
|
||||||
|
|
||||||
|
dataset = "reanalysis-era5-single-levels"
|
||||||
|
client = cdsapi.Client(wait_until_complete=False)
|
||||||
|
|
||||||
|
outdir = Path("/isipd/projects/p_aicore_pf/tohoel001/era5-cds").resolve()
|
||||||
|
outdir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
print(f"Downloading ERA5 data from {start_year} to {end_year}...")
|
||||||
|
for y in range(start_year, end_year + 1):
|
||||||
|
for month in [f"{i:02d}" for i in range(1, 13)]:
|
||||||
|
request = {
|
||||||
|
"product_type": ["reanalysis"],
|
||||||
|
"variable": [
|
||||||
|
"2m_temperature",
|
||||||
|
"total_precipitation",
|
||||||
|
"snow_depth",
|
||||||
|
"snow_density",
|
||||||
|
"snowfall",
|
||||||
|
"lake_ice_temperature",
|
||||||
|
"surface_sensible_heat_flux",
|
||||||
|
],
|
||||||
|
"year": [str(y)],
|
||||||
|
"month": [month],
|
||||||
|
"day": [f"{i:02d}" for i in range(1, 32)],
|
||||||
|
"time": [f"{i:02d}:00" for i in range(0, 24)],
|
||||||
|
"data_format": "netcdf",
|
||||||
|
"download_format": "unarchived",
|
||||||
|
"area": [85, -180, 50, 180],
|
||||||
|
}
|
||||||
|
|
||||||
|
outpath = outdir / f"era5_{y}_{month}.zip"
|
||||||
|
client.retrieve(dataset, request).download(str(outpath))
|
||||||
|
print(f"Downloaded {dataset} for {y}-{month}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cyclopts.run(hourly)
|
||||||
|
|
@ -9,12 +9,16 @@ dependencies = [
|
||||||
"aiohttp>=3.12.11",
|
"aiohttp>=3.12.11",
|
||||||
"bokeh>=3.7.3",
|
"bokeh>=3.7.3",
|
||||||
"cartopy>=0.24.1",
|
"cartopy>=0.24.1",
|
||||||
|
"cdsapi>=0.7.6",
|
||||||
"cyclopts>=3.17.0",
|
"cyclopts>=3.17.0",
|
||||||
"dask>=2025.5.1",
|
"dask>=2025.5.1",
|
||||||
"distributed>=2025.5.1",
|
"distributed>=2025.5.1",
|
||||||
"entropyc",
|
"earthengine-api>=1.6.9",
|
||||||
|
"eemont>=2025.7.1",
|
||||||
|
# "entropyc",
|
||||||
"flox>=0.10.4",
|
"flox>=0.10.4",
|
||||||
"folium>=0.19.7",
|
"folium>=0.19.7",
|
||||||
|
"geemap>=0.36.3",
|
||||||
"geopandas>=1.1.0",
|
"geopandas>=1.1.0",
|
||||||
"h3>=4.2.2",
|
"h3>=4.2.2",
|
||||||
"h5netcdf>=1.6.4",
|
"h5netcdf>=1.6.4",
|
||||||
|
|
@ -22,6 +26,7 @@ dependencies = [
|
||||||
"ipywidgets>=8.1.7",
|
"ipywidgets>=8.1.7",
|
||||||
"mapclassify>=2.9.0",
|
"mapclassify>=2.9.0",
|
||||||
"matplotlib>=3.10.3",
|
"matplotlib>=3.10.3",
|
||||||
|
"netcdf4>=1.7.2",
|
||||||
"numpy>=2.3.0",
|
"numpy>=2.3.0",
|
||||||
"odc-geo[all]>=0.4.10",
|
"odc-geo[all]>=0.4.10",
|
||||||
"pyarrow>=20.0.0",
|
"pyarrow>=20.0.0",
|
||||||
|
|
@ -39,5 +44,5 @@ dependencies = [
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
|
# entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
|
||||||
xanimate = { git = "https://github.com/davbyr/xAnimate" }
|
xanimate = { git = "https://github.com/davbyr/xAnimate" }
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue