"""Extract satellite embeddings from Google Earth Engine and map them to a grid."""

import math
from pathlib import Path
from typing import Literal

import cyclopts
import ee
import geemap
import geopandas as gpd
import pandas as pd
from rich import pretty, traceback
from rich.progress import track

pretty.install()
traceback.install()
ee.Initialize(project="ee-tobias-hoelzer")

DATA_DIR = Path("data")
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)

# Native pixel size of the Satellite Embedding dataset, in metres
EMBEDDING_SCALE = 10
# Number of grid cells sent to Earth Engine per request
BATCH_SIZE = 100


def cli(grid: Literal["hex", "healpix"], level: int, year: int):
    """Extract satellite embeddings from Google Earth Engine and map them to a grid.

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.
        year (int): The year to extract embeddings for. Must be between 2017 and 2024.

    """
    if not 2017 <= year <= 2024:
        raise ValueError(f"year must be between 2017 and 2024, got {year}")

    gridname = f"permafrost_{grid}{level}"
    # Use a separate name so the `grid` argument (a string) is not shadowed
    cells = gpd.read_parquet(DATA_DIR / f"grids/{gridname}_grid.parquet")

    # The end date of filterDate is exclusive, so stop at 1 January of the next year
    embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL").filterDate(
        f"{year}-01-01", f"{year + 1}-01-01"
    )
    bands = [f"A{i:02d}" for i in range(64)]  # A00 ... A63

    def extract_embedding(feature):
        # Mosaic the annual tiles that intersect the cell geometry
        geom = feature.geometry()
        embedding = embedding_collection.filterBounds(geom).mosaic()
        # Median of each embedding band over the cell. An explicit scale is needed
        # because a mosaic defaults to a WGS84 projection with 1-degree pixels;
        # bestEffort coarsens the scale for cells that would exceed maxPixels.
        median_dict = embedding.reduceRegion(
            reducer=ee.Reducer.median(),
            geometry=geom,
            scale=EMBEDDING_SCALE,
            bestEffort=True,
        )
        # Attach the per-band medians as properties of the feature
        return feature.set(median_dict)

    all_results = []
    # Round up so a final partial batch (or a grid smaller than one batch) is included
    n_batches = math.ceil(len(cells) / BATCH_SIZE)
    for batch_num in track(range(n_batches), description="Processing batches...", total=n_batches):
        batch_grid = cells.iloc[batch_num * BATCH_SIZE : (batch_num + 1) * BATCH_SIZE]
        print(f"Processing batch {batch_num} with {len(batch_grid)} cells")

        # Convert the batch to an EE FeatureCollection (Earth Engine expects EPSG:4326)
        eegrid_batch = ee.FeatureCollection(batch_grid.to_crs("epsg:4326").__geo_interface__)
        # Run the extraction server-side and download the results as a DataFrame
        df_batch = geemap.ee_to_df(eegrid_batch.map(extract_embedding))
        # Keep only the join key and the bands; missing band columns become NaN
        df_batch = df_batch.reindex(columns=["cell_id", *bands])

        # Save the batch immediately to disk as a backup
        batch_filename = f"{gridname}_embeddings-{year}_batch{batch_num:06d}.parquet"
        batch_result = batch_grid.merge(df_batch, on="cell_id", how="left")
        batch_result.to_parquet(EMBEDDINGS_DIR / batch_filename)

        all_results.append(df_batch)

    # Combine all batch results and join them back onto the full grid
    df = pd.concat(all_results, ignore_index=True)
    embeddings_on_grid = cells.merge(df, on="cell_id", how="left")
    embeddings_on_grid.to_parquet(EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet")


if __name__ == "__main__":
    cyclopts.run(cli)