151 lines
5.8 KiB
Python
151 lines
5.8 KiB
Python
"""Extract satellite embeddings from Google Earth Engine and map them to a grid.
|
|
|
|
Author: Tobias Hölzer
|
|
Date: October 2025
|
|
"""
|
|
|
|
import os
|
|
import warnings
|
|
from pathlib import Path
|
|
from typing import Literal
|
|
|
|
import cyclopts
|
|
import ee
|
|
import geemap
|
|
import geopandas as gpd
|
|
import numpy as np
|
|
import pandas as pd
|
|
import xarray as xr
|
|
from rich import pretty, print, traceback
|
|
from rich.progress import track
|
|
|
|
# Filter out the GeoDataFrame.swapaxes deprecation warning
# (np.array_split on a GeoDataFrame triggers it in the download command)
warnings.filterwarnings("ignore", message=".*GeoDataFrame.swapaxes.*", category=FutureWarning)

pretty.install()
traceback.install()
# Earth Engine cloud project is overridable via the EE_PROJECT environment
# variable (same pattern as DATA_DIR below); the default keeps prior behavior.
ee.Initialize(project=os.environ.get("EE_PROJECT", "ee-tobias-hoelzer"))

# Root data directory; DATA_DIR env var overrides the relative default.
DATA_DIR = Path(os.environ.get("DATA_DIR", "../../data")) / "entropyc-rts"
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)


cli = cyclopts.App(name="alpha-earth")
|
|
|
|
|
|
@cli.command()
def download(grid: Literal["hex", "healpix"], level: int, backup_intermediate: bool = False):
    """Extract satellite embeddings from Google Earth Engine and map them to a grid.

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.
        backup_intermediate (bool, optional): Whether to backup intermediate results. Defaults to False.

    """
    gridname = f"permafrost_{grid}{level}"
    # Keep the `grid` CLI argument intact instead of shadowing it with the GeoDataFrame.
    grid_gdf = gpd.read_parquet(DATA_DIR / f"grids/{gridname}_grid.parquet")

    for year in track(range(2017, 2025), total=8, description="Processing years..."):
        embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
        # filterDate's end date is exclusive; using Jan 1 of the following year
        # covers the whole year regardless of image timestamps (the previous
        # "{year}-12-31" end excluded Dec 31).
        embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year + 1}-01-01")
        aggs = ["median", "stdDev", "min", "max", "mean", "p1", "p5", "p25", "p75", "p95", "p99"]
        # Expected output property names, e.g. "A00_median" ... "A63_p99".
        bands = [f"A{str(i).zfill(2)}_{agg}" for i in range(64) for agg in aggs]

        def extract_embedding(feature):
            """Reduce the mosaicked embedding image over the feature's geometry."""
            # Filter collection by geometry
            geom = feature.geometry()
            embedding = embedding_collection.filterBounds(geom).mosaic()
            # Aggregate each embedding band over the geometry with a combined
            # reducer; result keys match the `bands` list above.
            mean_dict = embedding.reduceRegion(
                reducer=ee.Reducer.median()
                .combine(ee.Reducer.stdDev(), sharedInputs=True)
                .combine(ee.Reducer.minMax(), sharedInputs=True)
                .combine(ee.Reducer.mean(), sharedInputs=True)
                .combine(ee.Reducer.percentile([1, 5, 25, 75, 95, 99]), sharedInputs=True),
                geometry=geom,
            )
            # Add aggregated embedding values as properties to the feature
            return feature.set(mean_dict)

        # Process grid in batches of 100
        batch_size = 100
        all_results = []
        # Ceil division, clamped to at least one batch: the previous
        # `len(grid) // batch_size` yielded 0 for grids smaller than
        # `batch_size`, and np.array_split raises on 0 sections.
        n_batches = max(1, -(-len(grid_gdf) // batch_size))
        for batch_num, batch_grid in track(
            enumerate(np.array_split(grid_gdf, n_batches)),
            description="Processing batches...",
            total=n_batches,
        ):
            # Convert batch to EE FeatureCollection
            eegrid_batch = ee.FeatureCollection(batch_grid.to_crs("epsg:4326").__geo_interface__)

            # Apply embedding extraction to batch
            eeegrid_batch = eegrid_batch.map(extract_embedding)
            df_batch = geemap.ee_to_df(eeegrid_batch)

            # Store batch results
            all_results.append(df_batch)

            # Save batch immediately to disk as backup
            if backup_intermediate:
                batch_filename = f"{gridname}_embeddings-{year}_batch{batch_num:06d}.parquet"
                batch_result = batch_grid.merge(df_batch[[*bands, "cell_id"]], on="cell_id", how="left")
                batch_result.to_parquet(EMBEDDINGS_DIR / f"{batch_filename}")

        # Combine all batch results
        df = pd.concat(all_results, ignore_index=True)
        embeddings_on_grid = grid_gdf.merge(df[[*bands, "cell_id"]], on="cell_id", how="left")
        embeddings_file = EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet"
        embeddings_on_grid.to_parquet(embeddings_file)
        print(f"Saved embeddings for year {year} to {embeddings_file.resolve()}.")
|
|
|
|
|
|
@cli.command()
def combine_to_zarr(grid: Literal["hex", "healpix"], level: int):
    """Combine yearly embeddings parquet files into a single zarr store.

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.

    """
    reference = gpd.read_parquet(DATA_DIR / "embeddings" / f"permafrost_{grid}{level}_embeddings-2017.parquet")
    # ? Converting cell IDs from hex strings to integers for xdggs compatibility
    cells = [int(cell_id, 16) for cell_id in reference.cell_id.to_list()]
    years = list(range(2017, 2025))
    aggs = ["median", "stdDev", "min", "max", "mean", "p1", "p5", "p25", "p75", "p95", "p99"]
    bands = [f"A{str(i).zfill(2)}" for i in range(64)]

    # Pre-size a NaN-filled (year, cell, band, agg) cube.
    shape = (len(years), len(cells), len(bands), len(aggs))
    cube = xr.DataArray(
        np.full(shape, np.nan),
        dims=("year", "cell", "band", "agg"),
        coords={"year": years, "cell": cells, "band": bands, "agg": aggs},
    )
    # ? These attributes are needed for xdggs
    cell_attrs = {"grid_name": "h3" if grid == "hex" else "healpix", "level": level}
    if grid == "healpix":
        cell_attrs["indexing_scheme"] = "nested"
    cube.cell.attrs = cell_attrs

    for year in track(years, total=len(years), description="Processing years..."):
        embs = gpd.read_parquet(DATA_DIR / "embeddings" / f"permafrost_{grid}{level}_embeddings-{year}.parquet")
        # NOTE(review): positional assignment assumes every yearly file lists
        # cells in the same row order as the 2017 reference file — confirm the
        # download step guarantees this.
        for agg in aggs:
            for band in bands:
                cube.loc[{"band": band, "agg": agg, "year": year}] = embs[f"{band}_{agg}"].to_list()

    zarr_path = EMBEDDINGS_DIR / f"permafrost_{grid}{level}_embeddings.zarr"
    cube.to_zarr(zarr_path, consolidated=False, mode="w")
    print(f"Saved combined embeddings to {zarr_path.resolve()}.")
|
|
|
|
|
|
def main():
    """Run the alpha-earth command-line interface (console-script entry point)."""
    cli()
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running the script directly as well as via the installed entry point.
    main()
|