entropice/steps/s1_0_alphaearth/alphaearth.py
"""Extract satellite embeddings from Google Earth Engine and map them to a grid.
Author: Tobias Hölzer
Date: October 2025
"""
import os
import warnings
from pathlib import Path
from typing import Literal
import cyclopts
import ee
import geemap
import geopandas as gpd
import numpy as np
import pandas as pd
import xarray as xr
from rich import pretty, print, traceback
from rich.progress import track
# Filter out the GeoDataFrame.swapaxes deprecation warning
warnings.filterwarnings("ignore", message=".*GeoDataFrame.swapaxes.*", category=FutureWarning)
pretty.install()
traceback.install()
ee.Initialize(project="ee-tobias-hoelzer")
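# Assumes Earth Engine credentials are already configured (e.g. via `earthengine authenticate`);
# the project id above is specific to the author's setup.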
DATA_DIR = Path(os.environ.get("DATA_DIR", "../../data")) / "entropyc-rts"
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
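# Expected layout under DATA_DIR (inferred from the paths used below):
#   grids/permafrost_<grid><level>_grid.parquet  -> input grid definitions
#   embeddings/                                  -> per-year and combined outputs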
cli = cyclopts.App(name="alpha-earth")


@cli.command()
def download(grid: Literal["hex", "healpix"], level: int, backup_intermediate: bool = False):
"""Extract satellite embeddings from Google Earth Engine and map them to a grid.
Args:
grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use.
backup_intermediate (bool, optional): Whether to backup intermediate results. Defaults to False.
"""
    gridname = f"permafrost_{grid}{level}"
    # Read the target grid; use a separate name so the `grid` CLI argument is not shadowed
    grid_gdf = gpd.read_parquet(DATA_DIR / f"grids/{gridname}_grid.parquet")
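    # The grid parquet is expected to contain a "cell_id" column; it is used below to
    # join the extracted embeddings back onto the grid.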
    # Aggregation suffixes and the expected output column names (e.g. "A00_median"),
    # as produced by the combined reducers below
    aggs = ["median", "stdDev", "min", "max", "mean", "p1", "p5", "p25", "p75", "p95", "p99"]
    bands = [f"A{i:02d}_{agg}" for i in range(64) for agg in aggs]
    for year in track(range(2017, 2025), description="Processing years..."):
        embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL").filterDate(
            f"{year}-01-01", f"{year}-12-31"
        )
        def extract_embedding(feature):
            # Mosaic the images for this year that intersect the cell geometry
            geom = feature.geometry()
            embedding = embedding_collection.filterBounds(geom).mosaic()
            # Reduce the 64 embedding bands over the cell with several statistics
            stats_dict = embedding.reduceRegion(
                reducer=ee.Reducer.median()
                .combine(ee.Reducer.stdDev(), sharedInputs=True)
                .combine(ee.Reducer.minMax(), sharedInputs=True)
                .combine(ee.Reducer.mean(), sharedInputs=True)
                .combine(ee.Reducer.percentile([1, 5, 25, 75, 95, 99]), sharedInputs=True),
                geometry=geom,
            )
            # Attach the aggregated embedding values as properties of the feature
            return feature.set(stats_dict)
        # Process the grid in batches of roughly `batch_size` cells
        batch_size = 100
        all_results = []
        # Ceiling division so small grids still get at least one batch
        n_batches = max(1, -(-len(grid_gdf) // batch_size))
        for batch_num, batch_grid in track(
            enumerate(np.array_split(grid_gdf, n_batches)),
            description="Processing batches...",
            total=n_batches,
        ):
            # Convert the batch to an EE FeatureCollection
            eegrid_batch = ee.FeatureCollection(batch_grid.to_crs("epsg:4326").__geo_interface__)
            # Apply the embedding extraction to every feature in the batch
            eegrid_batch = eegrid_batch.map(extract_embedding)
            df_batch = geemap.ee_to_df(eegrid_batch)
            # Store batch results
            all_results.append(df_batch)
            # Save the batch to disk immediately as a backup
            if backup_intermediate:
                batch_filename = f"{gridname}_embeddings-{year}_batch{batch_num:06d}.parquet"
                batch_result = batch_grid.merge(df_batch[[*bands, "cell_id"]], on="cell_id", how="left")
                batch_result.to_parquet(EMBEDDINGS_DIR / batch_filename)
        # Combine all batch results and join them back onto the grid
        df = pd.concat(all_results, ignore_index=True)
        embeddings_on_grid = grid_gdf.merge(df[[*bands, "cell_id"]], on="cell_id", how="left")
        embeddings_file = EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet"
        embeddings_on_grid.to_parquet(embeddings_file)
        print(f"Saved embeddings for year {year} to {embeddings_file.resolve()}.")


@cli.command()
def combine_to_zarr(grid: Literal["hex", "healpix"], level: int):
"""Combine yearly embeddings parquet files into a single zarr store.
Args:
grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use.
"""
    embs = gpd.read_parquet(EMBEDDINGS_DIR / f"permafrost_{grid}{level}_embeddings-2017.parquet")
# ? Converting cell IDs from hex strings to integers for xdggs compatibility
cells = [int(cid, 16) for cid in embs.cell_id.to_list()]
years = list(range(2017, 2025))
aggs = ["median", "stdDev", "min", "max", "mean", "p1", "p5", "p25", "p75", "p95", "p99"]
bands = [f"A{str(i).zfill(2)}" for i in range(64)]
a = xr.DataArray(
np.nan,
dims=("year", "cell", "band", "agg"),
coords={"year": years, "cell": cells, "band": bands, "agg": aggs},
)
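    # Empty (NaN-filled) array of shape (n_years, n_cells, 64 bands, 11 aggregations),
    # populated year by year below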
# ? These attributes are needed for xdggs
a.cell.attrs = {
"grid_name": "h3" if grid == "hex" else "healpix",
"level": level,
}
if grid == "healpix":
a.cell.attrs["indexing_scheme"] = "nested"
    for year in track(years, description="Processing years..."):
        embs = gpd.read_parquet(EMBEDDINGS_DIR / f"permafrost_{grid}{level}_embeddings-{year}.parquet")
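        # Assumes each yearly parquet lists cells in the same order as the 2017 file used
        # to build the "cell" coordinate above, since values are assigned positionally.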
for band in bands:
for agg in aggs:
col = f"{band}_{agg}"
a.loc[{"band": band, "agg": agg, "year": year}] = embs[col].to_list()
zarr_path = EMBEDDINGS_DIR / f"permafrost_{grid}{level}_embeddings.zarr"
a.to_zarr(zarr_path, consolidated=False, mode="w")
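    # The store can later be opened and decoded roughly like this (sketch; assumes the
    # `xdggs` package is installed):
    #   ds = xr.open_zarr(zarr_path, consolidated=False)
    #   ds = ds.pipe(xdggs.decode)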
print(f"Saved combined embeddings to {zarr_path.resolve()}.")
def main(): # noqa: D103
cli()
if __name__ == "__main__":
main()
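
# Example invocations (sketch; illustrative grid/level values, and the exact names follow
# cyclopts' default transform, which typically maps "combine_to_zarr" to "combine-to-zarr"):
#   python alphaearth.py download hex 5 --backup-intermediate
#   python alphaearth.py combine-to-zarr hex 5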