Create Ensemble Datasets
This commit is contained in:
parent
33c9667383
commit
67030c9f0d
10 changed files with 839 additions and 626 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -25,3 +25,6 @@ playground.ipynb
|
||||||
# pixi environments
|
# pixi environments
|
||||||
.pixi
|
.pixi
|
||||||
*.egg-info
|
*.egg-info
|
||||||
|
|
||||||
|
# Ignore all notebooks for now
|
||||||
|
notebooks/*.ipynb
|
||||||
|
|
@ -45,7 +45,7 @@ dependencies = [
|
||||||
"ultraplot>=1.63.0",
|
"ultraplot>=1.63.0",
|
||||||
"xanimate",
|
"xanimate",
|
||||||
"xarray>=2025.9.0",
|
"xarray>=2025.9.0",
|
||||||
"xdggs>=0.2.1",
|
"xdggs",
|
||||||
"xvec>=0.5.1",
|
"xvec>=0.5.1",
|
||||||
"zarr[remote]>=3.1.3",
|
"zarr[remote]>=3.1.3",
|
||||||
"geocube>=0.7.1,<0.8",
|
"geocube>=0.7.1,<0.8",
|
||||||
|
|
@ -57,12 +57,16 @@ dependencies = [
|
||||||
"xgboost>=3.1.1,<4",
|
"xgboost>=3.1.1,<4",
|
||||||
"s3fs>=2025.10.0,<2026",
|
"s3fs>=2025.10.0,<2026",
|
||||||
"xarray-spatial",
|
"xarray-spatial",
|
||||||
"cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", "xarray-histogram>=0.2.2,<0.3", "antimeridian>=0.4.5,<0.5", "duckdb>=1.4.2,<2",
|
"cupy-xarray>=0.1.4,<0.2",
|
||||||
|
"memray>=1.19.1,<2",
|
||||||
|
"xarray-histogram>=0.2.2,<0.3",
|
||||||
|
"antimeridian>=0.4.5,<0.5",
|
||||||
|
"duckdb>=1.4.2,<2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
create-grid = "entropice.grids:main"
|
create-grid = "entropice.grids:main"
|
||||||
darts = "entropice.darts:main"
|
darts = "entropice.darts:cli"
|
||||||
alpha-earth = "entropice.alphaearth:main"
|
alpha-earth = "entropice.alphaearth:main"
|
||||||
era5 = "entropice.era5:cli"
|
era5 = "entropice.era5:cli"
|
||||||
arcticdem = "entropice.arcticdem:cli"
|
arcticdem = "entropice.arcticdem:cli"
|
||||||
|
|
@ -86,6 +90,7 @@ entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "re
|
||||||
entropy = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" }
|
entropy = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" }
|
||||||
xanimate = { git = "https://github.com/davbyr/xAnimate" }
|
xanimate = { git = "https://github.com/davbyr/xAnimate" }
|
||||||
xdem = { git = "https://github.com/GlacioHack/xdem" }
|
xdem = { git = "https://github.com/GlacioHack/xdem" }
|
||||||
|
xdggs = { git = "https://github.com/relativityhd/xdggs", branch = "feature/make-plotting-useful" }
|
||||||
xarray-spatial = { git = "https://github.com/relativityhd/xarray-spatial" }
|
xarray-spatial = { git = "https://github.com/relativityhd/xarray-spatial" }
|
||||||
cudf-cu12 = { index = "nvidia" }
|
cudf-cu12 = { index = "nvidia" }
|
||||||
cuml-cu12 = { index = "nvidia" }
|
cuml-cu12 = { index = "nvidia" }
|
||||||
|
|
@ -136,3 +141,4 @@ cudnn = ">=9.13.1.26,<10"
|
||||||
cusparselt = ">=0.8.1.1,<0.9"
|
cusparselt = ">=0.8.1.1,<0.9"
|
||||||
cuda-version = "12.9.*"
|
cuda-version = "12.9.*"
|
||||||
rapids = ">=25.10.0,<26"
|
rapids = ">=25.10.0,<26"
|
||||||
|
healpix-geo = ">=0.0.6"
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
#! /bin/bash
|
#! /bin/bash
|
||||||
|
|
||||||
pixi run darts --grid hex --level 3
|
pixi run darts extract_darts_mllabels --grid hex --level 3
|
||||||
pixi run darts --grid hex --level 4
|
pixi run darts extract_darts_mllabels --grid hex --level 4
|
||||||
pixi run darts --grid hex --level 5
|
pixi run darts extract_darts_mllabels --grid hex --level 5
|
||||||
pixi run darts --grid hex --level 6
|
pixi run darts extract_darts_mllabels --grid hex --level 6
|
||||||
pixi run darts --grid healpix --level 6
|
pixi run darts extract_darts_mllabels --grid healpix --level 6
|
||||||
pixi run darts --grid healpix --level 7
|
pixi run darts extract_darts_mllabels --grid healpix --level 7
|
||||||
pixi run darts --grid healpix --level 8
|
pixi run darts extract_darts_mllabels --grid healpix --level 8
|
||||||
pixi run darts --grid healpix --level 9
|
pixi run darts extract_darts_mllabels --grid healpix --level 9
|
||||||
pixi run darts --grid healpix --level 10
|
pixi run darts extract_darts_mllabels --grid healpix --level 10
|
||||||
|
|
@ -1,21 +1,21 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
pixi run alpha-earth download --grid hex --level 3
|
# pixi run alpha-earth download --grid hex --level 3
|
||||||
pixi run alpha-earth download --grid hex --level 4
|
# pixi run alpha-earth download --grid hex --level 4
|
||||||
pixi run alpha-earth download --grid hex --level 5
|
# pixi run alpha-earth download --grid hex --level 5
|
||||||
pixi run alpha-earth download --grid hex --level 6
|
# pixi run alpha-earth download --grid hex --level 6
|
||||||
pixi run alpha-earth download --grid healpix --level 6
|
# pixi run alpha-earth download --grid healpix --level 6
|
||||||
pixi run alpha-earth download --grid healpix --level 7
|
# pixi run alpha-earth download --grid healpix --level 7
|
||||||
pixi run alpha-earth download --grid healpix --level 8
|
# pixi run alpha-earth download --grid healpix --level 8
|
||||||
pixi run alpha-earth download --grid healpix --level 9
|
# pixi run alpha-earth download --grid healpix --level 9
|
||||||
pixi run alpha-earth download --grid healpix --level 10
|
# pixi run alpha-earth download --grid healpix --level 10
|
||||||
|
|
||||||
pixi run alpha-earth combine-to-zarr --grid hex --level 3
|
pixi run alpha-earth combine-to-zarr --grid hex --level 3
|
||||||
pixi run alpha-earth combine-to-zarr --grid hex --level 4
|
pixi run alpha-earth combine-to-zarr --grid hex --level 4
|
||||||
pixi run alpha-earth combine-to-zarr --grid hex --level 5
|
pixi run alpha-earth combine-to-zarr --grid hex --level 5
|
||||||
pixi run alpha-earth combine-to-zarr --grid hex --level 6
|
# pixi run alpha-earth combine-to-zarr --grid hex --level 6
|
||||||
pixi run alpha-earth combine-to-zarr --grid healpix --level 6
|
pixi run alpha-earth combine-to-zarr --grid healpix --level 6
|
||||||
pixi run alpha-earth combine-to-zarr --grid healpix --level 7
|
pixi run alpha-earth combine-to-zarr --grid healpix --level 7
|
||||||
pixi run alpha-earth combine-to-zarr --grid healpix --level 8
|
pixi run alpha-earth combine-to-zarr --grid healpix --level 8
|
||||||
pixi run alpha-earth combine-to-zarr --grid healpix --level 9
|
pixi run alpha-earth combine-to-zarr --grid healpix --level 9
|
||||||
pixi run alpha-earth combine-to-zarr --grid healpix --level 10
|
# pixi run alpha-earth combine-to-zarr --grid healpix --level 10
|
||||||
|
|
|
||||||
|
|
@ -5,14 +5,18 @@ Date: October 2025
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Generator
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
import cudf
|
||||||
|
import cuml.cluster
|
||||||
import cyclopts
|
import cyclopts
|
||||||
import ee
|
import ee
|
||||||
import geemap
|
import geemap
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import sklearn.cluster
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
import xdggs
|
import xdggs
|
||||||
from rich import pretty, print, traceback
|
from rich import pretty, print, traceback
|
||||||
|
|
@ -33,6 +37,22 @@ cli = cyclopts.App(name="alpha-earth")
|
||||||
# 7454521782,230147807,10000000.
|
# 7454521782,230147807,10000000.
|
||||||
|
|
||||||
|
|
||||||
|
def _batch_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[gpd.GeoDataFrame]:
    """Split a grid into batches by clustering the cell centroids with KMeans.

    Cells whose centroids lie close together end up in the same batch, so batches
    are spatially grouped rather than arbitrary row slices. Batch sizes are
    determined by the clustering and are generally not equal.

    Args:
        grid_gdf (gpd.GeoDataFrame): Grid cells to partition, one row per cell.
        n_partitions (int): Requested number of batches. Clamped to
            ``[1, len(grid_gdf)]`` because KMeans needs at least one cluster and
            cannot form more clusters than there are samples.

    Yields:
        gpd.GeoDataFrame: The rows of ``grid_gdf`` assigned to one cluster label,
            in increasing label order.

    """
    n_cells = len(grid_gdf)
    if n_cells == 0:
        # Nothing to cluster; KMeans would fail on an empty input.
        return
    n_partitions = max(1, min(n_partitions, n_cells))

    # Compute the centroids once and reuse them for both coordinates.
    centroid_points = grid_gdf.geometry.centroid
    centroids = pd.DataFrame({"x": centroid_points.x, "y": centroid_points.y})

    # Above this number of cells the GPU implementation (cuML) is used instead of scikit-learn.
    gpu_threshold = 100000
    if n_cells > gpu_threshold:
        print(f"Using cuML KMeans for partitioning {len(centroids)} centroids")
        centroids_cudf = cudf.DataFrame.from_pandas(centroids)
        kmeans = cuml.cluster.KMeans(n_clusters=n_partitions, random_state=42)
        labels = kmeans.fit_predict(centroids_cudf).to_pandas().to_numpy()
    else:
        labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids)

    for cluster in range(n_partitions):
        yield grid_gdf.iloc[np.where(labels == cluster)[0]]
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def download(grid: Literal["hex", "healpix"], level: int):
|
def download(grid: Literal["hex", "healpix"], level: int):
|
||||||
"""Extract satellite embeddings from Google Earth Engine and map them to a grid.
|
"""Extract satellite embeddings from Google Earth Engine and map them to a grid.
|
||||||
|
|
@ -77,11 +97,11 @@ def download(grid: Literal["hex", "healpix"], level: int):
|
||||||
return feature.set(mean_dict)
|
return feature.set(mean_dict)
|
||||||
|
|
||||||
# Process grid in batches of 100
|
# Process the grid in spatially clustered batches of roughly `batch_size` cells each
|
||||||
batch_size = 50
|
batch_size = 25
|
||||||
all_results = []
|
all_results = []
|
||||||
n_batches = len(grid_gdf) // batch_size
|
n_batches = len(grid_gdf) // batch_size
|
||||||
for batch_num, batch_grid in track(
|
for batch_grid in track(
|
||||||
enumerate(np.array_split(grid_gdf, n_batches)),
|
_batch_grid(grid_gdf, n_batches),
|
||||||
description="Processing batches...",
|
description="Processing batches...",
|
||||||
total=n_batches,
|
total=n_batches,
|
||||||
):
|
):
|
||||||
|
|
@ -121,6 +141,7 @@ def combine_to_zarr(grid: Literal["hex", "healpix"], level: int):
|
||||||
np.nan,
|
np.nan,
|
||||||
dims=("year", "cell_ids", "band", "agg"),
|
dims=("year", "cell_ids", "band", "agg"),
|
||||||
coords={"year": years, "cell_ids": cell_ids, "band": bands, "agg": aggs},
|
coords={"year": years, "cell_ids": cell_ids, "band": bands, "agg": aggs},
|
||||||
|
name="embeddings",
|
||||||
).astype(np.float32)
|
).astype(np.float32)
|
||||||
|
|
||||||
# ? These attributes are needed for xdggs
|
# ? These attributes are needed for xdggs
|
||||||
|
|
@ -141,8 +162,9 @@ def combine_to_zarr(grid: Literal["hex", "healpix"], level: int):
|
||||||
|
|
||||||
a = xdggs.decode(a)
|
a = xdggs.decode(a)
|
||||||
|
|
||||||
|
a = a.chunk({"year": 1, "cell_ids": 100000, "band": 64, "agg": 12})
|
||||||
zarr_path = get_embeddings_store(grid, level)
|
zarr_path = get_embeddings_store(grid, level)
|
||||||
a.to_zarr(zarr_path, consolidated=False, mode="w", encoding=codecs.from_ds(a))
|
a.to_zarr(zarr_path, consolidated=False, mode="w", encoding=codecs.from_da(a))
|
||||||
print(f"Saved combined embeddings to {zarr_path}.")
|
print(f"Saved combined embeddings to {zarr_path}.")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,3 +36,24 @@ def from_ds(
|
||||||
if ds[var].dtype == "float64":
|
if ds[var].dtype == "float64":
|
||||||
encoding[var]["dtype"] = "float32"
|
encoding[var]["dtype"] = "float32"
|
||||||
return encoding
|
return encoding
|
||||||
|
|
||||||
|
|
||||||
|
def from_da(da: xr.DataArray, store_floats_as_float32: bool = True) -> dict:
    """Create compression encoding for zarr DataArray storage.

    Creates a Blosc compression configuration for the DataArray
    using zstd compression with level 5.

    Args:
        da (xr.DataArray): The xarray DataArray to create encoding for.
        store_floats_as_float32 (bool, optional): Whether to store float64 data as float32.
            Defaults to True.

    Returns:
        dict: Encoding dictionary with compression settings, keyed by the variable name
            under which the DataArray is stored.

    """
    encoding = {"compressors": BloscCodec(cname="zstd", clevel=5)}
    if store_floats_as_float32 and da.dtype == "float64":
        encoding["dtype"] = "float32"
    # An unnamed DataArray would otherwise produce the key ``None``, which matches no stored variable.
    # Unnamed arrays are written under xarray's placeholder name, so fall back to it.
    name = da.name if da.name is not None else "__xarray_dataarray_variable__"
    return {name: encoding}
|
||||||
|
|
|
||||||
|
|
@ -10,17 +10,21 @@ from typing import Literal
|
||||||
|
|
||||||
import cyclopts
|
import cyclopts
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
|
import pandas as pd
|
||||||
from rich import pretty, print, traceback
|
from rich import pretty, print, traceback
|
||||||
from rich.progress import track
|
from rich.progress import track
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
from entropice import grids
|
from entropice import grids
|
||||||
from entropice.paths import dartsl2_cov_file, dartsl2_file, get_darts_rts_file
|
from entropice.paths import darts_ml_training_labels_repo, dartsl2_cov_file, dartsl2_file, get_darts_rts_file
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
||||||
|
cli = cyclopts.App(name="darts-rts")
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
|
def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
|
||||||
"""Extract RTS labels from DARTS dataset.
|
"""Extract RTS labels from DARTS dataset.
|
||||||
|
|
||||||
|
|
@ -81,8 +85,61 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
|
||||||
stopwatch.summary()
|
stopwatch.summary()
|
||||||
|
|
||||||
|
|
||||||
|
def _read_darts_mllabel_files(pattern: str, crs) -> gpd.GeoDataFrame:
    """Read and concatenate all GeoPackages of the ML training label repo matching a glob pattern.

    Args:
        pattern (str): Glob pattern relative to ``darts_ml_training_labels_repo``.
        crs: Target coordinate reference system the combined frame is reprojected to.

    Returns:
        gpd.GeoDataFrame: All matching files stacked into one frame with a fresh index, in ``crs``.

    Raises:
        FileNotFoundError: If no file matches the pattern.

    """
    files = sorted(darts_ml_training_labels_repo.glob(pattern))
    if not files:
        # Fail with an actionable message instead of pandas' "No objects to concatenate".
        raise FileNotFoundError(f"No files matching '{pattern}' found in {darts_ml_training_labels_repo}")
    combined = gpd.GeoDataFrame(pd.concat([gpd.read_file(f) for f in files]))
    return combined.reset_index(drop=True).to_crs(crs)


@cli.command()
def extract_darts_mllabels(grid: Literal["hex", "healpix"], level: int):
    """Extract RTS labels from the DARTS ML training labels.

    Intersects the training label polygons and the image footprints with the grid and
    stores per-cell RTS count, RTS area, covered area, coverage and RTS density.

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.

    """
    with stopwatch("Load data"):
        grid_gdf = grids.open(grid, level)

        darts_mllabels = _read_darts_mllabel_files("**/TrainingLabel*.gpkg", grid_gdf.crs)
        # Filter out invalid labels
        darts_mllabels = darts_mllabels[darts_mllabels.geometry.is_valid]
        # Remove overlapping labels by dissolving
        darts_mllabels = darts_mllabels[["geometry"]].dissolve().explode()

        darts_cov_mllabels = _read_darts_mllabel_files("**/ImageFootprints*.gpkg", grid_gdf.crs)
        darts_cov_mllabels = darts_cov_mllabels[["geometry"]].dissolve().explode()

    with stopwatch("Extract RTS labels"):
        # Both layers were already reprojected to the grid CRS while loading.
        grid_mllabels = grid_gdf.overlay(darts_mllabels, how="intersection")
        grid_cov_mllabels = grid_gdf.overlay(darts_cov_mllabels, how="intersection")

    with stopwatch("Processing RTS"):
        counts = grid_mllabels.groupby("cell_id").size()
        grid_gdf["dartsml_rts_count"] = grid_gdf.cell_id.map(counts)

        # Sum the intersection areas per cell.
        areas = grid_mllabels.geometry.area.groupby(grid_mllabels["cell_id"]).sum()
        grid_gdf["dartsml_rts_area"] = grid_gdf.cell_id.map(areas)

        areas_cov = grid_cov_mllabels.geometry.area.groupby(grid_cov_mllabels["cell_id"]).sum()
        grid_gdf["dartsml_covered_area"] = grid_gdf.cell_id.map(areas_cov)
        grid_gdf["dartsml_coverage"] = grid_gdf["dartsml_covered_area"] / grid_gdf.geometry.area
        grid_gdf["dartsml_rts_density"] = grid_gdf["dartsml_rts_area"] / grid_gdf["dartsml_covered_area"]

        # Apply corrections to NaNs: a covered cell without any label has a count of 0, not "unknown"
        covered = ~grid_gdf["dartsml_coverage"].isna()
        grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0)

        grid_gdf["darts_has_coverage"] = covered
        # Covered cells now carry a count of 0, so "not NaN" no longer implies an RTS is present.
        # Require a positive count; NaN compares as False.
        grid_gdf["darts_has_rts"] = grid_gdf["dartsml_rts_count"] > 0

    output_path = get_darts_rts_file(grid, level, labels=True)
    grid_gdf.to_parquet(output_path)
    print(f"Saved RTS labels to {output_path}")
    stopwatch.summary()
|
||||||
|
|
||||||
|
|
||||||
def main(): # noqa: D103
|
def main(): # noqa: D103
|
||||||
cyclopts.run(extract_darts_rts)
|
cli()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,19 @@
|
||||||
"""Training dataset preparation and model training."""
|
"""Training dataset preparation and model training.
|
||||||
|
|
||||||
|
Naming conventions:
|
||||||
|
|
||||||
|
- Ensemble Dataset
|
||||||
|
- Member-Datasets: selection based on config
|
||||||
|
- L2-Datasets: ready to use XDGGS Xarray Datasets
|
||||||
|
- Currently implemented: embeddings (alpha-earth), ERA5-seasonal, ERA5-shoulder, ERA5-yearly, Arcticdem32m
|
||||||
|
- All L2 Datasets can be multidimensional, but all of them have at least the cell_ids dimension
|
||||||
|
- All L2 Datasets have at least one data variable
|
||||||
|
- Dimensions of L2 Datasets are e.g. time or aggregation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import cyclopts
|
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
|
@ -11,12 +22,7 @@ from rich import pretty, traceback
|
||||||
from sklearn import set_config
|
from sklearn import set_config
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
from entropice.paths import (
|
import entropice.paths
|
||||||
get_darts_rts_file,
|
|
||||||
get_embeddings_store,
|
|
||||||
get_era5_stores,
|
|
||||||
get_train_dataset_file,
|
|
||||||
)
|
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
@ -26,105 +32,182 @@ set_config(array_api_dispatch=True)
|
||||||
sns.set_theme("talk", "whitegrid")
|
sns.set_theme("talk", "whitegrid")
|
||||||
|
|
||||||
|
|
||||||
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
|
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
|
||||||
seasons = {10: "winter", 4: "summer"}
|
|
||||||
|
|
||||||
|
|
||||||
@stopwatch.f("Prepare ERA5 data", print_kwargs=["temporal"])
|
|
||||||
def _prep_era5(
|
|
||||||
rts: gpd.GeoDataFrame,
|
|
||||||
temporal: Literal["yearly", "seasonal", "shoulder"],
|
|
||||||
grid: Literal["hex", "healpix"],
|
|
||||||
level: int,
|
|
||||||
) -> pd.DataFrame:
|
|
||||||
era5_df = []
|
|
||||||
era5_store = get_era5_stores(temporal, grid=grid, level=level)
|
|
||||||
era5 = xr.open_zarr(era5_store, consolidated=False)
|
|
||||||
era5 = era5.sel(cell_ids=rts["cell_id"].values)
|
|
||||||
|
|
||||||
for var in era5.data_vars:
|
|
||||||
df = era5[var].drop_vars("spatial_ref").to_dataframe()
|
|
||||||
if temporal == "yearly":
|
if temporal == "yearly":
|
||||||
df["t"] = df.index.get_level_values("time").year
|
return df.index.get_level_values("time").year
|
||||||
elif temporal == "seasonal":
|
elif temporal == "seasonal":
|
||||||
df["t"] = (
|
seasons = {10: "winter", 4: "summer"}
|
||||||
|
return (
|
||||||
df.index.get_level_values("time")
|
df.index.get_level_values("time")
|
||||||
.month.map(lambda x: seasons.get(x))
|
.month.map(lambda x: seasons.get(x))
|
||||||
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
||||||
)
|
)
|
||||||
elif temporal == "shoulder":
|
elif temporal == "shoulder":
|
||||||
df["t"] = (
|
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
|
||||||
|
return (
|
||||||
df.index.get_level_values("time")
|
df.index.get_level_values("time")
|
||||||
.month.map(lambda x: shoulder_seasons.get(x))
|
.month.map(lambda x: shoulder_seasons.get(x))
|
||||||
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
||||||
)
|
)
|
||||||
df = (
|
|
||||||
df.pivot_table(index="cell_ids", columns="t", values=var)
|
|
||||||
.rename(columns=lambda x: f"{var}_{x}")
|
type L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||||
.rename_axis(None, axis=1)
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatasetEnsemble:
|
||||||
|
grid: Literal["hex", "healpix"]
|
||||||
|
level: int
|
||||||
|
target: Literal["darts_rts", "darts_mllabels"]
|
||||||
|
members: list[L2Dataset]
|
||||||
|
dimension_filters: dict[L2Dataset, dict[str, list]] = field(default_factory=dict)
|
||||||
|
variable_filters: dict[L2Dataset, list[str]] = field(default_factory=dict)
|
||||||
|
filter_target: str | Literal[False] = False
|
||||||
|
add_lonlat: bool = True
|
||||||
|
|
||||||
|
def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
|
||||||
|
if member == "AlphaEarth":
|
||||||
|
store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level)
|
||||||
|
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
|
||||||
|
store = entropice.paths.get_era5_stores(member.split("-")[1], grid=self.grid, level=self.level)
|
||||||
|
elif member == "ArcticDEM":
|
||||||
|
store = entropice.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Member {member} not implemented.")
|
||||||
|
|
||||||
|
ds = xr.open_zarr(store, consolidated=False)
|
||||||
|
|
||||||
|
# Apply variable filters
|
||||||
|
if member in self.variable_filters:
|
||||||
|
assert isinstance(self.variable_filters[member], list) and len(self.variable_filters[member]) >= 1, (
|
||||||
|
f"Invalid variable filter for {member=}: {self.variable_filters[member]}"
|
||||||
|
" Variable filter values must be a list with one or more entries."
|
||||||
)
|
)
|
||||||
era5_df.append(df)
|
ds = ds[self.variable_filters[member]]
|
||||||
era5_df = pd.concat(era5_df, axis=1)
|
|
||||||
era5_df = era5_df.rename(columns={col: f"era5_{col}" for col in era5_df.columns if col != "cell_id"})
|
|
||||||
return era5_df
|
|
||||||
|
|
||||||
|
# Apply dimension filters
|
||||||
@stopwatch("Prepare embeddings data")
|
if member in self.dimension_filters:
|
||||||
def _prep_embeddings(rts: gpd.GeoDataFrame, grid: Literal["hex", "healpix"], level: int) -> pd.DataFrame:
|
for dim, values in self.dimension_filters[member].items():
|
||||||
embs_store = get_embeddings_store(grid=grid, level=level)
|
assert isinstance(values, list) and len(values) >= 1, (
|
||||||
embeddings = xr.open_zarr(embs_store, consolidated=False).__xarray_dataarray_variable__
|
f"Invalid dimension filter for {dim=}: {values}"
|
||||||
embeddings = embeddings.sel(cell=rts["cell_id"].values)
|
" Dimension filter values must be a list with one or more entries."
|
||||||
|
|
||||||
embeddings_df = embeddings.to_dataframe(name="value")
|
|
||||||
embeddings_df = embeddings_df.pivot_table(index="cell", columns=["year", "agg", "band"], values="value")
|
|
||||||
embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
|
|
||||||
|
|
||||||
embeddings_df = embeddings_df.rename(
|
|
||||||
columns={col: f"embeddings_{col}" for col in embeddings_df.columns if col != "cell_id"}
|
|
||||||
)
|
)
|
||||||
return embeddings_df
|
ds = ds.sel({dim: values})
|
||||||
|
|
||||||
|
# Drop all coordinates that are not dimension coordinates
|
||||||
|
for coord in ds.coords:
|
||||||
|
if coord not in ds.dims:
|
||||||
|
ds = ds.drop_vars(coord)
|
||||||
|
|
||||||
def prepare_dataset(grid: Literal["hex", "healpix"], level: int, filter_target: bool = False):
|
# Only load target cell ids
|
||||||
"""Prepare training dataset by combining DARTS RTS labels, ERA5 data, and embeddings.
|
intersecting_cell_ids = set(ds["cell_ids"].values).intersection(set(targets["cell_id"].values))
|
||||||
|
ds = ds.sel(cell_ids=list(intersecting_cell_ids))
|
||||||
|
|
||||||
Args:
|
# Actually read data into memory
|
||||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
if not lazy:
|
||||||
level (int): The grid level to use.
|
ds.load()
|
||||||
|
return ds
|
||||||
|
|
||||||
|
@stopwatch("Loading targets")
|
||||||
|
def _read_target(self) -> gpd.GeoDataFrame:
|
||||||
|
if self.target == "darts_rts":
|
||||||
|
target_store = entropice.paths.get_darts_rts_file(grid=self.grid, level=self.level)
|
||||||
|
elif self.target == "darts_mllabels":
|
||||||
|
target_store = entropice.paths.get_darts_rts_file(grid=self.grid, level=self.level, labels=True)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Target {self.target} not implemented.")
|
||||||
|
targets = gpd.read_parquet(target_store)
|
||||||
|
|
||||||
"""
|
|
||||||
rts = gpd.read_parquet(get_darts_rts_file(grid=grid, level=level))
|
|
||||||
# Filter to coverage
|
# Filter to coverage
|
||||||
if filter_target:
|
if self.filter_target:
|
||||||
rts = rts[rts["darts_has_coverage"]]
|
targets = targets[targets[self.filter_target]]
|
||||||
# Convert hex cell_id to int
|
# Convert hex cell_id to int
|
||||||
if grid == "hex":
|
if self.grid == "hex":
|
||||||
rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
|
targets["cell_id"] = targets["cell_id"].apply(lambda x: int(x, 16))
|
||||||
|
|
||||||
# Add the lat / lon of the cell centers
|
# Add the lat / lon of the cell centers
|
||||||
rts["lon"] = rts.geometry.centroid.x
|
if self.add_lonlat:
|
||||||
rts["lat"] = rts.geometry.centroid.y
|
targets["lon"] = targets.geometry.centroid.x
|
||||||
|
targets["lat"] = targets.geometry.centroid.y
|
||||||
|
|
||||||
# Get era5 data
|
return targets
|
||||||
era5_yearly = _prep_era5(rts, "yearly", grid, level)
|
|
||||||
era5_seasonal = _prep_era5(rts, "seasonal", grid, level)
|
|
||||||
era5_shoulder = _prep_era5(rts, "shoulder", grid, level)
|
|
||||||
|
|
||||||
# Get embeddings data
|
@stopwatch.f("Prepare ERA5 data", print_kwargs=["temporal"])
|
||||||
embeddings = _prep_embeddings(rts, grid, level)
|
def _prep_era5(
|
||||||
|
self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
era5_df = []
|
||||||
|
era5 = self._read_member(f"ERA5-{temporal}", targets)
|
||||||
|
|
||||||
|
for var in era5.data_vars:
|
||||||
|
df = era5[var].to_dataframe()
|
||||||
|
df["t"] = _get_era5_tempus(df, temporal)
|
||||||
|
# If aggregations is not in dims, we can pivot directly
|
||||||
|
if "aggregations" not in era5.dims:
|
||||||
|
df = (
|
||||||
|
df.pivot_table(index="cell_ids", columns="t", values=var)
|
||||||
|
.rename(columns=lambda x: f"era5_{var}_{x}")
|
||||||
|
.rename_axis(None, axis=1)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
df = df.pivot_table(index="cell_ids", columns=["t", "aggregations"], values=var)
|
||||||
|
df.columns = [f"era5_{var}_{t}_{agg}" for t, agg in df.columns]
|
||||||
|
era5_df.append(df)
|
||||||
|
era5_df = pd.concat(era5_df, axis=1)
|
||||||
|
return era5_df
|
||||||
|
|
||||||
|
@stopwatch("Prepare embeddings data")
|
||||||
|
def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
|
||||||
|
embeddings = self._read_member("AlphaEarth", targets)["embeddings"]
|
||||||
|
|
||||||
|
embeddings_df = embeddings.to_dataframe(name="value")
|
||||||
|
embeddings_df = embeddings_df.pivot_table(index="cell_ids", columns=["year", "agg", "band"], values="value")
|
||||||
|
embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
|
||||||
|
return embeddings_df
|
||||||
|
|
||||||
|
@stopwatch("Prepare arcticdem data")
|
||||||
|
def _prep_arcticdem(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
|
||||||
|
arcticdem = self._read_member("ArcticDEM", targets)
|
||||||
|
|
||||||
|
arcticdem_df = arcticdem.to_dataframe().pivot_table(index="cell_ids", columns="aggregations")
|
||||||
|
arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns]
|
||||||
|
return arcticdem_df
|
||||||
|
|
||||||
|
def print_stats(self):
|
||||||
|
targets = self._read_target()
|
||||||
|
print(f"=== Target: {self.target}")
|
||||||
|
print(f"\tNumber of target samples: {len(targets)}")
|
||||||
|
|
||||||
|
n_cols = 2 if self.add_lonlat else 0 # Lat and Lon
|
||||||
|
for member in self.members:
|
||||||
|
ds = self._read_member(member, targets, lazy=True)
|
||||||
|
print(f"=== Member: {member}")
|
||||||
|
print(f"\tVariables ({len(ds.data_vars)}): {list(ds.data_vars)}")
|
||||||
|
print(f"\tDimensions: {dict(ds.sizes)}")
|
||||||
|
print(f"\tCoordinates: {list(ds.coords)}")
|
||||||
|
n_cols_member = len(ds.data_vars)
|
||||||
|
for dim in ds.sizes:
|
||||||
|
if dim != "cell_ids":
|
||||||
|
n_cols_member *= ds.sizes[dim]
|
||||||
|
print(f"\tNumber of features from member: {n_cols_member}")
|
||||||
|
n_cols += n_cols_member
|
||||||
|
print(f"=== Total number of features in dataset: {n_cols}")
|
||||||
|
|
||||||
|
def create(self) -> pd.DataFrame:
|
||||||
|
targets = self._read_target()
|
||||||
|
|
||||||
|
member_dfs = []
|
||||||
|
for member in self.members:
|
||||||
|
if member.startswith("ERA5"):
|
||||||
|
member_dfs.append(self._prep_era5(targets, member.split("-")[1]))
|
||||||
|
elif member == "AlphaEarth":
|
||||||
|
member_dfs.append(self._prep_embeddings(targets))
|
||||||
|
elif member == "ArcticDEM":
|
||||||
|
member_dfs.append(self._prep_arcticdem(targets))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Member {member} not implemented.")
|
||||||
|
|
||||||
# Combine datasets by cell id / cell
|
|
||||||
with stopwatch("Combine datasets"):
|
with stopwatch("Combine datasets"):
|
||||||
dataset = rts.set_index("cell_id").join(era5_yearly).join(era5_seasonal).join(era5_shoulder).join(embeddings)
|
dataset = targets.set_index("cell_id").join(member_dfs)
|
||||||
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
|
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
|
||||||
|
return dataset
|
||||||
dataset_file = get_train_dataset_file(grid=grid, level=level)
|
|
||||||
dataset.reset_index().to_parquet(dataset_file)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
cyclopts.run(prepare_dataset)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
|
||||||
|
|
||||||
dartsl2_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
|
dartsl2_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
|
||||||
dartsl2_cov_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
|
dartsl2_cov_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
|
||||||
|
darts_ml_training_labels_repo = DARTS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
|
||||||
|
|
||||||
|
|
||||||
def _get_gridname(grid: Literal["hex", "healpix"], level: int) -> str:
|
def _get_gridname(grid: Literal["hex", "healpix"], level: int) -> str:
|
||||||
|
|
@ -52,8 +53,11 @@ def get_grid_viz_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||||
return vizfile
|
return vizfile
|
||||||
|
|
||||||
|
|
||||||
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool = False) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
|
if labels:
|
||||||
|
rtsfile = DARTS_DIR / f"{gridname}_darts-mllabels.parquet"
|
||||||
|
else:
|
||||||
rtsfile = DARTS_DIR / f"{gridname}_darts.parquet"
|
rtsfile = DARTS_DIR / f"{gridname}_darts.parquet"
|
||||||
return rtsfile
|
return rtsfile
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue