"""Extract satellite embeddings from Google Earth Engine and map them to a grid.

Author: Tobias Hölzer
Date: October 2025
"""

import os
import warnings
from pathlib import Path
from typing import Literal

import cyclopts
import ee
import geemap
import geopandas as gpd
import numpy as np
import pandas as pd
import xarray as xr
from rich import pretty, print, traceback
from rich.progress import track

# Filter out the GeoDataFrame.swapaxes deprecation warning
warnings.filterwarnings("ignore", message=".*GeoDataFrame.swapaxes.*", category=FutureWarning)

pretty.install()
traceback.install()
ee.Initialize(project="ee-tobias-hoelzer")

DATA_DIR = Path(os.environ.get("DATA_DIR", "../../data")) / "entropyc-rts"
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)

# Years of the annual embedding product processed by both commands.
YEARS = range(2017, 2025)
# Statistic suffixes produced by `_build_reducer`, in the order they are stored.
AGGS = ["median", "stdDev", "min", "max", "mean", "p1", "p5", "p25", "p75", "p95", "p99"]
# Embedding band names A00 ... A63.
BANDS = [f"A{i:02d}" for i in range(64)]
# Per-cell statistic columns "<band>_<agg>", band-major (all aggs of A00, then A01, ...).
STAT_COLUMNS = [f"{band}_{agg}" for band in BANDS for agg in AGGS]

cli = cyclopts.App(name="alpha-earth")


def _build_reducer():
    """Return one combined reducer computing every statistic listed in AGGS.

    With ``sharedInputs=True`` all reducers see the same band values, and the
    output properties are named ``<band>_<stat>`` (e.g. ``A00_median``,
    ``A00_min``, ``A00_max``, ``A00_p95``).
    """
    return (
        ee.Reducer.median()
        .combine(ee.Reducer.stdDev(), sharedInputs=True)
        .combine(ee.Reducer.minMax(), sharedInputs=True)
        .combine(ee.Reducer.mean(), sharedInputs=True)
        .combine(ee.Reducer.percentile([1, 5, 25, 75, 95, 99]), sharedInputs=True)
    )


def _make_extractor(collection, reducer, scale: float):
    """Return a function for ``FeatureCollection.map`` that attaches embedding stats.

    For each feature, the images of ``collection`` intersecting its geometry are
    mosaicked and reduced over the geometry with ``reducer`` at ``scale`` metres;
    the resulting statistics are set as properties on the feature.
    """

    def extract_embedding(feature):
        geom = feature.geometry()
        embedding = collection.filterBounds(geom).mosaic()
        # A mosaic has no fixed projection (Earth Engine falls back to WGS84 at
        # 1 degree), so without an explicit scale the reduction would run on
        # ~111 km pixels instead of the dataset's native resolution.
        stats = embedding.reduceRegion(
            reducer=reducer,
            geometry=geom,
            scale=scale,
            # Large cells at 10 m easily exceed the default limit of 1e7 pixels.
            maxPixels=1e13,
        )
        return feature.set(stats)

    return extract_embedding


@cli.command()
def download(
    grid: Literal["hex", "healpix"],
    level: int,
    backup_intermediate: bool = False,
    scale: float = 10.0,
    batch_size: int = 100,
):
    """Extract satellite embeddings from Google Earth Engine and map them to a grid.

    For every year in YEARS, computes per-cell statistics (see AGGS) of each of
    the 64 embedding bands and writes ``<grid-name>_embeddings-<year>.parquet``
    to EMBEDDINGS_DIR.

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.
        backup_intermediate (bool, optional): Whether to backup intermediate results. Defaults to False.
        scale (float, optional): Pixel scale in metres used for the per-cell reduction.
            Defaults to 10.0 (the native resolution of the embedding dataset).
        batch_size (int, optional): Number of grid cells sent to Earth Engine per request.
            Defaults to 100.

    Raises:
        ValueError: If ``batch_size`` is not positive or the grid file contains no cells.

    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    gridname = f"permafrost_{grid}{level}"
    # Separate name so the `grid` argument (the grid type) is not shadowed.
    grid_gdf = gpd.read_parquet(DATA_DIR / f"grids/{gridname}_grid.parquet")
    if grid_gdf.empty:
        raise ValueError(f"Grid {gridname} contains no cells.")

    # Earth Engine expects WGS84 geometries; reproject once, not per batch/year.
    grid_wgs84 = grid_gdf.to_crs("epsg:4326")
    reducer = _build_reducer()
    # Ceiling-sized batches: a grid smaller than batch_size still yields one batch.
    batch_starts = range(0, len(grid_gdf), batch_size)

    for year in track(YEARS, total=len(YEARS), description="Processing years..."):
        # filterDate's end date is exclusive, so Jan 1 of the next year keeps Dec 31.
        embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL").filterDate(
            f"{year}-01-01", f"{year + 1}-01-01"
        )
        extract_embedding = _make_extractor(embedding_collection, reducer, scale)

        all_results = []
        for batch_num, start in track(
            enumerate(batch_starts),
            description="Processing batches...",
            total=len(batch_starts),
        ):
            batch_grid = grid_gdf.iloc[start : start + batch_size]
            eegrid_batch = ee.FeatureCollection(grid_wgs84.iloc[start : start + batch_size].__geo_interface__)
            df_batch = geemap.ee_to_df(eegrid_batch.map(extract_embedding))
            # Cells without valid pixels get no statistic properties; if that holds
            # for a whole batch the columns are missing entirely. Reindex so every
            # expected column exists (NaN-filled) and extra properties are dropped.
            df_batch = df_batch.reindex(columns=[*STAT_COLUMNS, "cell_id"])
            all_results.append(df_batch)

            # Save batch immediately to disk as backup
            if backup_intermediate:
                batch_filename = f"{gridname}_embeddings-{year}_batch{batch_num:06d}.parquet"
                batch_result = batch_grid.merge(df_batch, on="cell_id", how="left")
                batch_result.to_parquet(EMBEDDINGS_DIR / batch_filename)

        df = pd.concat(all_results, ignore_index=True)
        embeddings_on_grid = grid_gdf.merge(df, on="cell_id", how="left")
        embeddings_file = EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet"
        embeddings_on_grid.to_parquet(embeddings_file)
        print(f"Saved embeddings for year {year} to {embeddings_file.resolve()}.")


@cli.command()
def combine_to_zarr(grid: Literal["hex", "healpix"], level: int):
    """Combine yearly embeddings parquet files into a single zarr store.

    The resulting DataArray has dims ``(year, cell, band, agg)``; cells without
    data in a given year are NaN.

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.

    """
    gridname = f"permafrost_{grid}{level}"

    def read_year(year: int) -> gpd.GeoDataFrame:
        return gpd.read_parquet(EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet")

    # The first year defines the cell axis; every year is aligned to it by cell_id
    # below instead of relying on identical row order across files.
    cell_ids = read_year(YEARS[0]).cell_id.to_list()
    # ? Converting cell IDs from hex strings to integers for xdggs compatibility
    cells = [int(cid, 16) if isinstance(cid, str) else int(cid) for cid in cell_ids]
    years = list(YEARS)

    a = xr.DataArray(
        np.nan,
        dims=("year", "cell", "band", "agg"),
        coords={"year": years, "cell": cells, "band": BANDS, "agg": AGGS},
    )
    # ? These attributes are needed for xdggs
    a.cell.attrs = {
        "grid_name": "h3" if grid == "hex" else "healpix",
        "level": level,
    }
    if grid == "healpix":
        a.cell.attrs["indexing_scheme"] = "nested"

    for year in track(years, total=len(years), description="Processing years..."):
        embs = read_year(year).set_index("cell_id").reindex(cell_ids)
        # STAT_COLUMNS is band-major, so a (cell, band*agg) matrix reshapes
        # directly into the (cell, band, agg) slice for this year.
        values = embs[STAT_COLUMNS].to_numpy(dtype=float)
        a.loc[{"year": year}] = values.reshape(len(cell_ids), len(BANDS), len(AGGS))

    zarr_path = EMBEDDINGS_DIR / f"{gridname}_embeddings.zarr"
    a.to_zarr(zarr_path, consolidated=False, mode="w")
    print(f"Saved combined embeddings to {zarr_path.resolve()}.")


def main():  # noqa: D103
    cli()


if __name__ == "__main__":
    main()