"""Extract satellite embeddings from Google Earth Engine and map them to a grid.""" import os from pathlib import Path from typing import Literal import cyclopts import ee import geemap import geopandas as gpd import numpy as np import pandas as pd from rich import pretty, traceback from rich.progress import track pretty.install() traceback.install() ee.Initialize(project="ee-tobias-hoelzer") DATA_DIR = Path(os.environ.get("DATA_DIR", "data")) / "entropyc-rts" EMBEDDINGS_DIR = DATA_DIR / "embeddings" EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True) def cli(grid: Literal["hex", "healpix"], level: int, backup_intermediate: bool = False): """Extract satellite embeddings from Google Earth Engine and map them to a grid. Args: grid (Literal["hex", "healpix"]): The grid type to use. level (int): The grid level to use. backup_intermediate (bool, optional): Whether to backup intermediate results. Defaults to False. """ gridname = f"permafrost_{grid}{level}" grid = gpd.read_parquet(DATA_DIR / f"grids/{gridname}_grid.parquet") for year in track(range(2022, 2025), total=3, description="Processing years..."): embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL") embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31") bands = [f"A{str(i).zfill(2)}" for i in range(64)] def extract_embedding(feature): # Filter collection by geometry geom = feature.geometry() embedding = embedding_collection.filterBounds(geom).mosaic() # Get mean embedding value for the geometry mean_dict = embedding.reduceRegion( reducer=ee.Reducer.median(), geometry=geom, ) # Add mean embedding values as properties to the feature return feature.set(mean_dict) # Process grid in batches of 100 batch_size = 100 all_results = [] n_batches = len(grid) // batch_size for batch_num, batch_grid in track( enumerate(np.array_split(grid, n_batches)), description="Processing batches...", total=n_batches, ): # Convert batch to EE FeatureCollection eegrid_batch = ee.FeatureCollection(batch_grid.to_crs("epsg:4326").__geo_interface__) # Apply embedding extraction to batch eeegrid_batch = eegrid_batch.map(extract_embedding) df_batch = geemap.ee_to_df(eeegrid_batch) # Store batch results all_results.append(df_batch) # Save batch immediately to disk as backup if backup_intermediate: batch_filename = f"{gridname}_embeddings-{year}_batch{batch_num:06d}.parquet" batch_result = batch_grid.merge(df_batch[[*bands, "cell_id"]], on="cell_id", how="left") batch_result.to_parquet(EMBEDDINGS_DIR / f"{batch_filename}") # Combine all batch results df = pd.concat(all_results, ignore_index=True) embeddings_on_grid = grid.merge(df[[*bands, "cell_id"]], on="cell_id", how="left") embeddings_on_grid.to_parquet(EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet") if __name__ == "__main__": cyclopts.run(cli)