Reduce compute time of alphaearth

This commit is contained in:
Tobias Hölzer 2025-11-23 16:18:23 +01:00
parent e5382670ec
commit 18cc1b8601
2 changed files with 19 additions and 13 deletions

View file

@@ -3,21 +3,19 @@
pixi run alpha-earth download --grid hex --level 3
pixi run alpha-earth download --grid hex --level 4
pixi run alpha-earth download --grid hex --level 5
pixi run alpha-earth download --grid hex --level 6
pixi run alpha-earth download --grid healpix --level 6
pixi run alpha-earth download --grid healpix --level 7
pixi run alpha-earth download --grid healpix --level 8
pixi run alpha-earth download --grid healpix --level 9
pixi run alpha-earth download --grid healpix --level 10
pixi run alpha-earth combine-to-zarr --grid hex --level 3
pixi run alpha-earth combine-to-zarr --grid hex --level 4
pixi run alpha-earth combine-to-zarr --grid hex --level 5
pixi run alpha-earth combine-to-zarr --grid hex --level 6
pixi run alpha-earth combine-to-zarr --grid healpix --level 6
pixi run alpha-earth combine-to-zarr --grid healpix --level 7
pixi run alpha-earth combine-to-zarr --grid healpix --level 8
pixi run alpha-earth combine-to-zarr --grid healpix --level 9
pixi run alpha-earth download --grid hex --level 6
pixi run alpha-earth download --grid healpix --level 10
pixi run alpha-earth combine-to-zarr --grid hex --level 6
pixi run alpha-earth combine-to-zarr --grid healpix --level 10

View file

@@ -44,10 +44,18 @@ def download(grid: Literal["hex", "healpix"], level: int):
"""
grid_gdf = grids.open(grid, level)
for year in track(range(2017, 2025), total=8, description="Processing years..."):
# Hardcoded scale factors, depending on grid and level, so that approx. 10000px are sampled per grid cell
scale_factors = {
"hex": {3: 1600, 4: 600, 5: 240, 6: 90},
"healpix": {6: 1600, 7: 800, 8: 400, 9: 200, 10: 100},
}
scale_factor = scale_factors[grid][level]
print(f"Using scale factor of {scale_factor} for grid {grid} at level {level}.")
for year in track(range(2018, 2025), total=7, description="Processing years..."):
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
aggs = ["median", "stdDev", "min", "max", "mean", "p1", "p5", "p25", "p75", "p95", "p99"]
aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
bands = [f"A{str(i).zfill(2)}_{agg}" for i in range(64) for agg in aggs]
def extract_embedding(feature):
@@ -56,14 +64,14 @@ def download(grid: Literal["hex", "healpix"], level: int):
embedding = embedding_collection.filterBounds(geom).mosaic()
# Get mean embedding value for the geometry
mean_dict = embedding.reduceRegion(
reducer=ee.Reducer.median()
reducer=ee.Reducer.mean()
.combine(ee.Reducer.stdDev(), sharedInputs=True)
.combine(ee.Reducer.minMax(), sharedInputs=True)
.combine(ee.Reducer.mean(), sharedInputs=True)
.combine(ee.Reducer.count(), sharedInputs=True)
.combine(ee.Reducer.median(), sharedInputs=True)
.combine(ee.Reducer.percentile([1, 5, 25, 75, 95, 99]), sharedInputs=True),
geometry=geom,
scale=10,
bestEffort=True,
scale=scale_factor,
)
# Add mean embedding values as properties to the feature
return feature.set(mean_dict)
@@ -105,8 +113,8 @@ def combine_to_zarr(grid: Literal["hex", "healpix"], level: int):
"""
cell_ids = grids.get_cell_ids(grid, level)
years = list(range(2017, 2025))
aggs = ["median", "stdDev", "min", "max", "mean", "p1", "p5", "p25", "p75", "p95", "p99"]
years = list(range(2018, 2025))
aggs = ["mean", "stdDev", "min", "max", "count", "median", "p1", "p5", "p25", "p75", "p95", "p99"]
bands = [f"A{str(i).zfill(2)}" for i in range(64)]
a = xr.DataArray(