Make era5 and alphaearth downloads work

This commit is contained in:
Tobias Hölzer 2025-10-01 14:44:24 +02:00
parent c0c3700be8
commit 2af5c011a3
6 changed files with 441 additions and 196 deletions

View file

@ -1,5 +1,6 @@
"""Extract satellite embeddings from Google Earth Engine and map them to a grid."""
import os
from pathlib import Path
from typing import Literal
@ -16,68 +17,69 @@ pretty.install()
traceback.install()
ee.Initialize(project="ee-tobias-hoelzer")
DATA_DIR = Path("data")
DATA_DIR = Path(os.environ.get("DATA_DIR", "data")) / "entropyc-rts"
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
def cli(grid: Literal["hex", "healpix"], level: int, year: int):
def cli(grid: Literal["hex", "healpix"], level: int, backup_intermediate: bool = False):
"""Extract satellite embeddings from Google Earth Engine and map them to a grid.
Args:
grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use.
year (int): The year to extract embeddings for. Must be between 2017 and 2024.
backup_intermediate (bool, optional): Whether to backup intermediate results. Defaults to False.
"""
gridname = f"permafrost_{grid}{level}"
grid = gpd.read_parquet(DATA_DIR / f"grids/{gridname}_grid.parquet")
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
bands = [f"A{str(i).zfill(2)}" for i in range(64)]
def extract_embedding(feature):
# Filter collection by geometry
geom = feature.geometry()
embedding = embedding_collection.filterBounds(geom).mosaic()
# Get mean embedding value for the geometry
mean_dict = embedding.reduceRegion(
reducer=ee.Reducer.median(),
geometry=geom,
)
# Add mean embedding values as properties to the feature
return feature.set(mean_dict)
for year in track(range(2022, 2025), total=3, description="Processing years..."):
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
bands = [f"A{str(i).zfill(2)}" for i in range(64)]
# Process grid in batches of 100
batch_size = 100
all_results = []
n_batches = len(grid) // batch_size
for batch_num, batch_grid in track(
enumerate(np.array_split(grid, n_batches)),
description="Processing batches...",
total=n_batches,
):
print(f"Processing batch with {len(batch_grid)} items")
def extract_embedding(feature):
# Filter collection by geometry
geom = feature.geometry()
embedding = embedding_collection.filterBounds(geom).mosaic()
# Get mean embedding value for the geometry
mean_dict = embedding.reduceRegion(
reducer=ee.Reducer.median(),
geometry=geom,
)
# Add mean embedding values as properties to the feature
return feature.set(mean_dict)
# Convert batch to EE FeatureCollection
eegrid_batch = ee.FeatureCollection(batch_grid.to_crs("epsg:4326").__geo_interface__)
# Process grid in batches of 100
batch_size = 100
all_results = []
n_batches = len(grid) // batch_size
for batch_num, batch_grid in track(
enumerate(np.array_split(grid, n_batches)),
description="Processing batches...",
total=n_batches,
):
# Convert batch to EE FeatureCollection
eegrid_batch = ee.FeatureCollection(batch_grid.to_crs("epsg:4326").__geo_interface__)
# Apply embedding extraction to batch
eeegrid_batch = eegrid_batch.map(extract_embedding)
df_batch = geemap.ee_to_df(eeegrid_batch)
# Apply embedding extraction to batch
eeegrid_batch = eegrid_batch.map(extract_embedding)
df_batch = geemap.ee_to_df(eeegrid_batch)
# Save batch immediately to disk as backup
batch_filename = f"{gridname}_embeddings-{year}_batch{batch_num:06d}.parquet"
batch_result = batch_grid.merge(df_batch[[*bands, "cell_id"]], on="cell_id", how="left")
batch_result.to_parquet(EMBEDDINGS_DIR / f"{batch_filename}")
# Store batch results
all_results.append(df_batch)
# Store batch results
all_results.append(df_batch)
# Save batch immediately to disk as backup
if backup_intermediate:
batch_filename = f"{gridname}_embeddings-{year}_batch{batch_num:06d}.parquet"
batch_result = batch_grid.merge(df_batch[[*bands, "cell_id"]], on="cell_id", how="left")
batch_result.to_parquet(EMBEDDINGS_DIR / f"{batch_filename}")
# Combine all batch results
df = pd.concat(all_results, ignore_index=True)
embeddings_on_grid = grid.merge(df[[*bands, "cell_id"]], on="cell_id", how="left")
embeddings_on_grid.to_parquet(EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet")
# Combine all batch results
df = pd.concat(all_results, ignore_index=True)
embeddings_on_grid = grid.merge(df[[*bands, "cell_id"]], on="cell_id", how="left")
embeddings_on_grid.to_parquet(EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet")
if __name__ == "__main__":