entropice/alphaearth.py

84 lines
2.9 KiB
Python

"""Extract satellite embeddings from Google Earth Engine and map them to a grid."""
from pathlib import Path
from typing import Literal
import cyclopts
import ee
import geemap
import geopandas as gpd
import numpy as np
import pandas as pd
from rich import pretty, traceback
from rich.progress import track
# Rich-formatted value display and tracebacks for this process.
pretty.install()
traceback.install()

# Initialise the Earth Engine client for the cloud project; this runs at import time.
ee.Initialize(project="ee-tobias-hoelzer")

# Input/output locations, relative to the current working directory.
DATA_DIR = Path("data")
EMBEDDINGS_DIR = DATA_DIR.joinpath("embeddings")
# Create the output directory (and any missing parents) up front.
EMBEDDINGS_DIR.mkdir(exist_ok=True, parents=True)
def cli(grid: Literal["hex", "healpix"], level: int, year: int):
    """Extract satellite embeddings from Google Earth Engine and map them to a grid.

    For every grid cell, the per-band median of the annual satellite embedding
    mosaic is reduced on Earth Engine. Cells are sent in batches of at most 100;
    each batch is written to its own parquet file as a backup, and the combined
    result is written once all batches have been processed.

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.
        year (int): The year to extract embeddings for. Must be between 2017 and 2024.

    Raises:
        ValueError: If the grid file contains no cells.

    """
    gridname = f"permafrost_{grid}{level}"
    # Keep the GeoDataFrame under its own name instead of rebinding the `grid` string argument.
    grid_gdf = gpd.read_parquet(DATA_DIR / f"grids/{gridname}_grid.parquet")
    if grid_gdf.empty:
        raise ValueError(f"Grid '{gridname}' contains no cells, nothing to extract.")

    embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
    embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
    bands = [f"A{str(i).zfill(2)}" for i in range(64)]
    result_columns = [*bands, "cell_id"]

    def extract_embedding(feature):
        # Restrict the collection to images touching this cell and flatten them into one image
        geom = feature.geometry()
        embedding = embedding_collection.filterBounds(geom).mosaic()
        # Median embedding value per band over the cell geometry.
        # NOTE(review): no `scale`/`crs` is passed, so Earth Engine derives the reduction
        # scale from the mosaic's default projection - confirm this is the intended resolution.
        median_dict = embedding.reduceRegion(
            reducer=ee.Reducer.median(),
            geometry=geom,
        )
        # Attach the median band values as properties of the feature
        return feature.set(median_dict)

    # Process the grid in batches of at most `batch_size` cells. Slicing by start offset
    # also works for grids smaller than one batch and never produces oversized batches.
    batch_size = 100
    batch_starts = range(0, len(grid_gdf), batch_size)
    n_batches = len(batch_starts)
    all_results = []
    for batch_num, start in track(
        enumerate(batch_starts),
        description="Processing batches...",
        total=n_batches,
    ):
        batch_grid = grid_gdf.iloc[start : start + batch_size]
        print(f"Processing batch with {len(batch_grid)} items")
        # Convert batch to an EE FeatureCollection (Earth Engine expects lon/lat coordinates)
        ee_batch = ee.FeatureCollection(batch_grid.to_crs("epsg:4326").__geo_interface__)
        # Apply embedding extraction to the batch and pull the result back as a DataFrame
        ee_batch_embedded = ee_batch.map(extract_embedding)
        df_batch = geemap.ee_to_df(ee_batch_embedded)
        # Save batch immediately to disk as backup
        batch_filename = f"{gridname}_embeddings-{year}_batch{batch_num:06d}.parquet"
        batch_result = batch_grid.merge(df_batch[result_columns], on="cell_id", how="left")
        batch_result.to_parquet(EMBEDDINGS_DIR / f"{batch_filename}")
        # Keep the batch for the final combined output
        all_results.append(df_batch)

    # Combine all batch results and join them back onto the full grid
    df = pd.concat(all_results, ignore_index=True)
    embeddings_on_grid = grid_gdf.merge(df[result_columns], on="cell_id", how="left")
    embeddings_on_grid.to_parquet(EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet")
if __name__ == "__main__":
    # Expose `cli` as the command-line entry point when the module is run as a script.
    cyclopts.run(cli)