Create Ensemble Datasets
This commit is contained in:
parent
33c9667383
commit
67030c9f0d
10 changed files with 839 additions and 626 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -25,3 +25,6 @@ playground.ipynb
|
|||
# pixi environments
|
||||
.pixi
|
||||
*.egg-info
|
||||
|
||||
# Ignore all notebooks for now
|
||||
notebooks/*.ipynb
|
||||
|
|
@ -45,7 +45,7 @@ dependencies = [
|
|||
"ultraplot>=1.63.0",
|
||||
"xanimate",
|
||||
"xarray>=2025.9.0",
|
||||
"xdggs>=0.2.1",
|
||||
"xdggs",
|
||||
"xvec>=0.5.1",
|
||||
"zarr[remote]>=3.1.3",
|
||||
"geocube>=0.7.1,<0.8",
|
||||
|
|
@ -57,12 +57,16 @@ dependencies = [
|
|||
"xgboost>=3.1.1,<4",
|
||||
"s3fs>=2025.10.0,<2026",
|
||||
"xarray-spatial",
|
||||
"cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", "xarray-histogram>=0.2.2,<0.3", "antimeridian>=0.4.5,<0.5", "duckdb>=1.4.2,<2",
|
||||
"cupy-xarray>=0.1.4,<0.2",
|
||||
"memray>=1.19.1,<2",
|
||||
"xarray-histogram>=0.2.2,<0.3",
|
||||
"antimeridian>=0.4.5,<0.5",
|
||||
"duckdb>=1.4.2,<2",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
create-grid = "entropice.grids:main"
|
||||
darts = "entropice.darts:main"
|
||||
darts = "entropice.darts:cli"
|
||||
alpha-earth = "entropice.alphaearth:main"
|
||||
era5 = "entropice.era5:cli"
|
||||
arcticdem = "entropice.arcticdem:cli"
|
||||
|
|
@ -86,6 +90,7 @@ entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "re
|
|||
entropy = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" }
|
||||
xanimate = { git = "https://github.com/davbyr/xAnimate" }
|
||||
xdem = { git = "https://github.com/GlacioHack/xdem" }
|
||||
xdggs = { git = "https://github.com/relativityhd/xdggs", branch = "feature/make-plotting-useful" }
|
||||
xarray-spatial = { git = "https://github.com/relativityhd/xarray-spatial" }
|
||||
cudf-cu12 = { index = "nvidia" }
|
||||
cuml-cu12 = { index = "nvidia" }
|
||||
|
|
@ -136,3 +141,4 @@ cudnn = ">=9.13.1.26,<10"
|
|||
cusparselt = ">=0.8.1.1,<0.9"
|
||||
cuda-version = "12.9.*"
|
||||
rapids = ">=25.10.0,<26"
|
||||
healpix-geo = ">=0.0.6"
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
#! /bin/bash
|
||||
|
||||
pixi run darts --grid hex --level 3
|
||||
pixi run darts --grid hex --level 4
|
||||
pixi run darts --grid hex --level 5
|
||||
pixi run darts --grid hex --level 6
|
||||
pixi run darts --grid healpix --level 6
|
||||
pixi run darts --grid healpix --level 7
|
||||
pixi run darts --grid healpix --level 8
|
||||
pixi run darts --grid healpix --level 9
|
||||
pixi run darts --grid healpix --level 10
|
||||
pixi run darts extract_darts_mllabels --grid hex --level 3
|
||||
pixi run darts extract_darts_mllabels --grid hex --level 4
|
||||
pixi run darts extract_darts_mllabels --grid hex --level 5
|
||||
pixi run darts extract_darts_mllabels --grid hex --level 6
|
||||
pixi run darts extract_darts_mllabels --grid healpix --level 6
|
||||
pixi run darts extract_darts_mllabels --grid healpix --level 7
|
||||
pixi run darts extract_darts_mllabels --grid healpix --level 8
|
||||
pixi run darts extract_darts_mllabels --grid healpix --level 9
|
||||
pixi run darts extract_darts_mllabels --grid healpix --level 10
|
||||
|
|
@ -1,21 +1,21 @@
|
|||
#!/bin/bash
|
||||
|
||||
pixi run alpha-earth download --grid hex --level 3
|
||||
pixi run alpha-earth download --grid hex --level 4
|
||||
pixi run alpha-earth download --grid hex --level 5
|
||||
pixi run alpha-earth download --grid hex --level 6
|
||||
pixi run alpha-earth download --grid healpix --level 6
|
||||
pixi run alpha-earth download --grid healpix --level 7
|
||||
pixi run alpha-earth download --grid healpix --level 8
|
||||
pixi run alpha-earth download --grid healpix --level 9
|
||||
pixi run alpha-earth download --grid healpix --level 10
|
||||
# pixi run alpha-earth download --grid hex --level 3
|
||||
# pixi run alpha-earth download --grid hex --level 4
|
||||
# pixi run alpha-earth download --grid hex --level 5
|
||||
# pixi run alpha-earth download --grid hex --level 6
|
||||
# pixi run alpha-earth download --grid healpix --level 6
|
||||
# pixi run alpha-earth download --grid healpix --level 7
|
||||
# pixi run alpha-earth download --grid healpix --level 8
|
||||
# pixi run alpha-earth download --grid healpix --level 9
|
||||
# pixi run alpha-earth download --grid healpix --level 10
|
||||
|
||||
pixi run alpha-earth combine-to-zarr --grid hex --level 3
|
||||
pixi run alpha-earth combine-to-zarr --grid hex --level 4
|
||||
pixi run alpha-earth combine-to-zarr --grid hex --level 5
|
||||
pixi run alpha-earth combine-to-zarr --grid hex --level 6
|
||||
# pixi run alpha-earth combine-to-zarr --grid hex --level 6
|
||||
pixi run alpha-earth combine-to-zarr --grid healpix --level 6
|
||||
pixi run alpha-earth combine-to-zarr --grid healpix --level 7
|
||||
pixi run alpha-earth combine-to-zarr --grid healpix --level 8
|
||||
pixi run alpha-earth combine-to-zarr --grid healpix --level 9
|
||||
pixi run alpha-earth combine-to-zarr --grid healpix --level 10
|
||||
# pixi run alpha-earth combine-to-zarr --grid healpix --level 10
|
||||
|
|
|
|||
|
|
@ -5,14 +5,18 @@ Date: October 2025
|
|||
"""
|
||||
|
||||
import warnings
|
||||
from collections.abc import Generator
|
||||
from typing import Literal
|
||||
|
||||
import cudf
|
||||
import cuml.cluster
|
||||
import cyclopts
|
||||
import ee
|
||||
import geemap
|
||||
import geopandas as gpd
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import sklearn.cluster
|
||||
import xarray as xr
|
||||
import xdggs
|
||||
from rich import pretty, print, traceback
|
||||
|
|
@ -33,6 +37,22 @@ cli = cyclopts.App(name="alpha-earth")
|
|||
# 7454521782,230147807,10000000.
|
||||
|
||||
|
||||
def _batch_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[pd.DataFrame]:
    """Split a grid into batches by KMeans-clustering the cell centroids.

    Args:
        grid_gdf (gpd.GeoDataFrame): The grid cells to partition.
        n_partitions (int): The requested number of batches. The value is clamped to
            ``[1, len(grid_gdf)]`` because KMeans accepts neither zero clusters nor more
            clusters than samples.

    Yields:
        pd.DataFrame: The rows of ``grid_gdf`` assigned to one cluster. Clusters without
            any member are skipped, so fewer than ``n_partitions`` batches may be yielded.

    """
    n_cells = len(grid_gdf)
    if n_cells == 0:
        # Nothing to partition
        return
    n_clusters = max(1, min(n_partitions, n_cells))

    # Compute the centroids once and reuse them for both coordinates
    centroid = grid_gdf.geometry.centroid
    centroids = pd.DataFrame({"x": centroid.x, "y": centroid.y})

    # use cuml and cudf if len of centroids is larger than 100000
    if n_cells > 100000:
        print(f"Using cuML KMeans for partitioning {len(centroids)} centroids")
        centroids_cudf = cudf.DataFrame.from_pandas(centroids)
        kmeans = cuml.cluster.KMeans(n_clusters=n_clusters, random_state=42)
        labels = kmeans.fit_predict(centroids_cudf).to_pandas().to_numpy()
    else:
        labels = sklearn.cluster.KMeans(n_clusters=n_clusters, random_state=42).fit_predict(centroids)

    for i in range(n_clusters):
        members = np.where(labels == i)[0]
        if len(members) > 0:
            yield grid_gdf.iloc[members]
|
||||
|
||||
|
||||
@cli.command()
|
||||
def download(grid: Literal["hex", "healpix"], level: int):
|
||||
"""Extract satellite embeddings from Google Earth Engine and map them to a grid.
|
||||
|
|
@ -77,11 +97,11 @@ def download(grid: Literal["hex", "healpix"], level: int):
|
|||
return feature.set(mean_dict)
|
||||
|
||||
# Process the grid in spatially clustered batches of roughly `batch_size` cells
|
||||
batch_size = 50
|
||||
batch_size = 25
|
||||
all_results = []
|
||||
n_batches = len(grid_gdf) // batch_size
|
||||
for batch_num, batch_grid in track(
|
||||
enumerate(np.array_split(grid_gdf, n_batches)),
|
||||
for batch_grid in track(
|
||||
_batch_grid(grid_gdf, n_batches),
|
||||
description="Processing batches...",
|
||||
total=n_batches,
|
||||
):
|
||||
|
|
@ -121,6 +141,7 @@ def combine_to_zarr(grid: Literal["hex", "healpix"], level: int):
|
|||
np.nan,
|
||||
dims=("year", "cell_ids", "band", "agg"),
|
||||
coords={"year": years, "cell_ids": cell_ids, "band": bands, "agg": aggs},
|
||||
name="embeddings",
|
||||
).astype(np.float32)
|
||||
|
||||
# ? These attributes are needed for xdggs
|
||||
|
|
@ -141,8 +162,9 @@ def combine_to_zarr(grid: Literal["hex", "healpix"], level: int):
|
|||
|
||||
a = xdggs.decode(a)
|
||||
|
||||
a = a.chunk({"year": 1, "cell_ids": 100000, "band": 64, "agg": 12})
|
||||
zarr_path = get_embeddings_store(grid, level)
|
||||
a.to_zarr(zarr_path, consolidated=False, mode="w", encoding=codecs.from_ds(a))
|
||||
a.to_zarr(zarr_path, consolidated=False, mode="w", encoding=codecs.from_da(a))
|
||||
print(f"Saved combined embeddings to {zarr_path}.")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -36,3 +36,24 @@ def from_ds(
|
|||
if ds[var].dtype == "float64":
|
||||
encoding[var]["dtype"] = "float32"
|
||||
return encoding
|
||||
|
||||
|
||||
def from_da(da: xr.DataArray, store_floats_as_float32: bool = True) -> dict:
    """Create compression encoding for zarr DataArray storage.

    Creates Blosc compression configuration for the DataArray
    using zstd compression with level 5.

    Args:
        da (xr.DataArray): The xarray DataArray to create encoding for.
        store_floats_as_float32 (bool, optional): Whether to store floating point data as float32.
            Defaults to True.

    Returns:
        dict: Encoding dictionary with compression settings, keyed by the variable name
            under which the DataArray is written.

    """
    encoding = {"compressors": BloscCodec(cname="zstd", clevel=5)}
    if store_floats_as_float32 and da.dtype == "float64":
        encoding["dtype"] = "float32"

    # xarray writes a DataArray under a generic variable name if it has no name or if its
    # name clashes with one of its coordinates / dimensions. The encoding key must match
    # that variable name, otherwise the encoding does not refer to any written variable.
    name = da.name
    if name is None or name in da.coords or name in da.dims:
        name = "__xarray_dataarray_variable__"
    return {name: encoding}
|
||||
|
|
|
|||
|
|
@ -10,17 +10,21 @@ from typing import Literal
|
|||
|
||||
import cyclopts
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
from rich import pretty, print, traceback
|
||||
from rich.progress import track
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice import grids
|
||||
from entropice.paths import dartsl2_cov_file, dartsl2_file, get_darts_rts_file
|
||||
from entropice.paths import darts_ml_training_labels_repo, dartsl2_cov_file, dartsl2_file, get_darts_rts_file
|
||||
|
||||
traceback.install()
|
||||
pretty.install()
|
||||
|
||||
cli = cyclopts.App(name="darts-rts")
|
||||
|
||||
|
||||
@cli.command()
|
||||
def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
|
||||
"""Extract RTS labels from DARTS dataset.
|
||||
|
||||
|
|
@ -81,8 +85,61 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
|
|||
stopwatch.summary()
|
||||
|
||||
|
||||
def _read_dissolved_layers(pattern: str, crs, drop_invalid: bool = False) -> gpd.GeoDataFrame:
    """Read all GeoPackages matching `pattern` in the ML training labels repo as non-overlapping polygons."""
    files = list(darts_ml_training_labels_repo.glob(pattern))
    if not files:
        raise FileNotFoundError(f"No files matching '{pattern}' found in {darts_ml_training_labels_repo}")
    layers = gpd.GeoDataFrame(pd.concat([gpd.read_file(f) for f in files])).reset_index(drop=True).to_crs(crs)
    if drop_invalid:
        layers = layers[layers.geometry.is_valid]
    # Remove overlaps by dissolving everything into one geometry and splitting it into its parts again
    return layers[["geometry"]].dissolve().explode()


@cli.command()
def extract_darts_mllabels(grid: Literal["hex", "healpix"], level: int):
    """Extract RTS labels from the DARTS ML training labels and aggregate them per grid cell.

    Writes the grid with the additional columns ``dartsml_rts_count``, ``dartsml_rts_area``,
    ``dartsml_covered_area``, ``dartsml_coverage``, ``dartsml_rts_density``,
    ``darts_has_coverage`` and ``darts_has_rts`` to a parquet file.

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.

    Raises:
        FileNotFoundError: If no training label or image footprint files are found.

    """
    with stopwatch("Load data"):
        grid_gdf = grids.open(grid, level)
        # Invalid label geometries are filtered out before dissolving
        darts_mllabels = _read_dissolved_layers("**/TrainingLabel*.gpkg", grid_gdf.crs, drop_invalid=True)
        darts_cov_mllabels = _read_dissolved_layers("**/ImageFootprints*.gpkg", grid_gdf.crs)

    with stopwatch("Extract RTS labels"):
        # Both layers are already in the CRS of the grid
        grid_mllabels = grid_gdf.overlay(darts_mllabels, how="intersection")
        grid_cov_mllabels = grid_gdf.overlay(darts_cov_mllabels, how="intersection")

    with stopwatch("Processing RTS"):
        counts = grid_mllabels.groupby("cell_id").size()
        grid_gdf["dartsml_rts_count"] = grid_gdf.cell_id.map(counts)

        # Sum the intersected areas per cell
        areas = grid_mllabels.geometry.area.groupby(grid_mllabels["cell_id"]).sum()
        grid_gdf["dartsml_rts_area"] = grid_gdf.cell_id.map(areas)

        areas_cov = grid_cov_mllabels.geometry.area.groupby(grid_cov_mllabels["cell_id"]).sum()
        grid_gdf["dartsml_covered_area"] = grid_gdf.cell_id.map(areas_cov)
        grid_gdf["dartsml_coverage"] = grid_gdf["dartsml_covered_area"] / grid_gdf.geometry.area
        grid_gdf["dartsml_rts_density"] = grid_gdf["dartsml_rts_area"] / grid_gdf["dartsml_covered_area"]

        # Apply corrections to NaNs
        # NOTE(review): only the count is filled with 0 for covered cells, area and density stay NaN - confirm
        covered = ~grid_gdf["dartsml_coverage"].isna()
        grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0)

        grid_gdf["darts_has_coverage"] = ~grid_gdf["dartsml_coverage"].isna()
        # NOTE(review): after the fill above this is True for every covered cell, also for a count of 0.
        # Confirm whether `> 0` was intended.
        grid_gdf["darts_has_rts"] = ~grid_gdf["dartsml_rts_count"].isna()

    output_path = get_darts_rts_file(grid, level, labels=True)
    grid_gdf.to_parquet(output_path)
    print(f"Saved RTS labels to {output_path}")
    stopwatch.summary()
|
||||
|
||||
|
||||
def main():
    """Run the DARTS command line interface with all registered commands."""
    # The cyclopts app exposes every function registered via `@cli.command()`
    cli()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,8 +1,19 @@
|
|||
"""Training dataset preparation and model training."""
|
||||
"""Training dataset preparation and model training.
|
||||
|
||||
Naming conventions:
|
||||
|
||||
- Ensemble Dataset
|
||||
- Member-Datasets: selection based on config
|
||||
- L2-Datasets: ready to use XDGGS Xarray Datasets
|
||||
- Currently implemented: embeddings (alpha-earth), ERA5-seasonal, ERA5-shoulder, ERA5-yearly, Arcticdem32m
|
||||
- All L2 Datasets can be multidimensional, but all have at least the cell_ids dimension
|
||||
- All L2 Datasets have at least one data variable
|
||||
- Dimensions of L2 Datasets are e.g. time or aggregation
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
import cyclopts
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
|
|
@ -11,12 +22,7 @@ from rich import pretty, traceback
|
|||
from sklearn import set_config
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.paths import (
|
||||
get_darts_rts_file,
|
||||
get_embeddings_store,
|
||||
get_era5_stores,
|
||||
get_train_dataset_file,
|
||||
)
|
||||
import entropice.paths
|
||||
|
||||
traceback.install()
|
||||
pretty.install()
|
||||
|
|
@ -26,105 +32,182 @@ set_config(array_api_dispatch=True)
|
|||
sns.set_theme("talk", "whitegrid")
|
||||
|
||||
|
||||
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
|
||||
seasons = {10: "winter", 4: "summer"}
|
||||
|
||||
|
||||
@stopwatch.f("Prepare ERA5 data", print_kwargs=["temporal"])
|
||||
def _prep_era5(
|
||||
rts: gpd.GeoDataFrame,
|
||||
temporal: Literal["yearly", "seasonal", "shoulder"],
|
||||
grid: Literal["hex", "healpix"],
|
||||
level: int,
|
||||
) -> pd.DataFrame:
|
||||
era5_df = []
|
||||
era5_store = get_era5_stores(temporal, grid=grid, level=level)
|
||||
era5 = xr.open_zarr(era5_store, consolidated=False)
|
||||
era5 = era5.sel(cell_ids=rts["cell_id"].values)
|
||||
|
||||
for var in era5.data_vars:
|
||||
df = era5[var].drop_vars("spatial_ref").to_dataframe()
|
||||
if temporal == "yearly":
|
||||
df["t"] = df.index.get_level_values("time").year
|
||||
elif temporal == "seasonal":
|
||||
df["t"] = (
|
||||
df.index.get_level_values("time")
|
||||
.month.map(lambda x: seasons.get(x))
|
||||
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
||||
)
|
||||
elif temporal == "shoulder":
|
||||
df["t"] = (
|
||||
df.index.get_level_values("time")
|
||||
.month.map(lambda x: shoulder_seasons.get(x))
|
||||
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
||||
)
|
||||
df = (
|
||||
df.pivot_table(index="cell_ids", columns="t", values=var)
|
||||
.rename(columns=lambda x: f"{var}_{x}")
|
||||
.rename_axis(None, axis=1)
|
||||
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
|
||||
if temporal == "yearly":
|
||||
return df.index.get_level_values("time").year
|
||||
elif temporal == "seasonal":
|
||||
seasons = {10: "winter", 4: "summer"}
|
||||
return (
|
||||
df.index.get_level_values("time")
|
||||
.month.map(lambda x: seasons.get(x))
|
||||
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
||||
)
|
||||
elif temporal == "shoulder":
|
||||
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
|
||||
return (
|
||||
df.index.get_level_values("time")
|
||||
.month.map(lambda x: shoulder_seasons.get(x))
|
||||
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
|
||||
)
|
||||
era5_df.append(df)
|
||||
era5_df = pd.concat(era5_df, axis=1)
|
||||
era5_df = era5_df.rename(columns={col: f"era5_{col}" for col in era5_df.columns if col != "cell_id"})
|
||||
return era5_df
|
||||
|
||||
|
||||
@stopwatch("Prepare embeddings data")
|
||||
def _prep_embeddings(rts: gpd.GeoDataFrame, grid: Literal["hex", "healpix"], level: int) -> pd.DataFrame:
|
||||
embs_store = get_embeddings_store(grid=grid, level=level)
|
||||
embeddings = xr.open_zarr(embs_store, consolidated=False).__xarray_dataarray_variable__
|
||||
embeddings = embeddings.sel(cell=rts["cell_id"].values)
|
||||
|
||||
embeddings_df = embeddings.to_dataframe(name="value")
|
||||
embeddings_df = embeddings_df.pivot_table(index="cell", columns=["year", "agg", "band"], values="value")
|
||||
embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
|
||||
|
||||
embeddings_df = embeddings_df.rename(
|
||||
columns={col: f"embeddings_{col}" for col in embeddings_df.columns if col != "cell_id"}
|
||||
)
|
||||
return embeddings_df
|
||||
type L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
|
||||
|
||||
def prepare_dataset(grid: Literal["hex", "healpix"], level: int, filter_target: bool = False):
|
||||
"""Prepare training dataset by combining DARTS RTS labels, ERA5 data, and embeddings.
|
||||
@dataclass
|
||||
class DatasetEnsemble:
|
||||
grid: Literal["hex", "healpix"]
|
||||
level: int
|
||||
target: Literal["darts_rts", "darts_mllabels"]
|
||||
members: list[L2Dataset]
|
||||
dimension_filters: dict[L2Dataset, dict[str, list]] = field(default_factory=dict)
|
||||
variable_filters: dict[L2Dataset, list[str]] = field(default_factory=dict)
|
||||
filter_target: str | Literal[False] = False
|
||||
add_lonlat: bool = True
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
|
||||
if member == "AlphaEarth":
|
||||
store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level)
|
||||
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
|
||||
store = entropice.paths.get_era5_stores(member.split("-")[1], grid=self.grid, level=self.level)
|
||||
elif member == "ArcticDEM":
|
||||
store = entropice.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
|
||||
else:
|
||||
raise NotImplementedError(f"Member {member} not implemented.")
|
||||
|
||||
"""
|
||||
rts = gpd.read_parquet(get_darts_rts_file(grid=grid, level=level))
|
||||
# Filter to coverage
|
||||
if filter_target:
|
||||
rts = rts[rts["darts_has_coverage"]]
|
||||
# Convert hex cell_id to int
|
||||
if grid == "hex":
|
||||
rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
|
||||
ds = xr.open_zarr(store, consolidated=False)
|
||||
|
||||
# Add the lat / lon of the cell centers
|
||||
rts["lon"] = rts.geometry.centroid.x
|
||||
rts["lat"] = rts.geometry.centroid.y
|
||||
# Apply variable filters
|
||||
if member in self.variable_filters:
|
||||
assert isinstance(self.variable_filters[member], list) and len(self.variable_filters[member]) >= 1, (
|
||||
f"Invalid variable filter for {member=}: {self.variable_filters[member]}"
|
||||
" Variable filter values must be a list with one or more entries."
|
||||
)
|
||||
ds = ds[self.variable_filters[member]]
|
||||
|
||||
# Get era5 data
|
||||
era5_yearly = _prep_era5(rts, "yearly", grid, level)
|
||||
era5_seasonal = _prep_era5(rts, "seasonal", grid, level)
|
||||
era5_shoulder = _prep_era5(rts, "shoulder", grid, level)
|
||||
# Apply dimension filters
|
||||
if member in self.dimension_filters:
|
||||
for dim, values in self.dimension_filters[member].items():
|
||||
assert isinstance(values, list) and len(values) >= 1, (
|
||||
f"Invalid dimension filter for {dim=}: {values}"
|
||||
" Dimension filter values must be a list with one or more entries."
|
||||
)
|
||||
ds = ds.sel({dim: values})
|
||||
|
||||
# Get embeddings data
|
||||
embeddings = _prep_embeddings(rts, grid, level)
|
||||
# Delete all coordinates which are not in the dimension
|
||||
for coord in ds.coords:
|
||||
if coord not in ds.dims:
|
||||
ds = ds.drop_vars(coord)
|
||||
|
||||
# Combine datasets by cell id / cell
|
||||
with stopwatch("Combine datasets"):
|
||||
dataset = rts.set_index("cell_id").join(era5_yearly).join(era5_seasonal).join(era5_shoulder).join(embeddings)
|
||||
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
|
||||
# Only load target cell ids
|
||||
intersecting_cell_ids = set(ds["cell_ids"].values).intersection(set(targets["cell_id"].values))
|
||||
ds = ds.sel(cell_ids=list(intersecting_cell_ids))
|
||||
|
||||
dataset_file = get_train_dataset_file(grid=grid, level=level)
|
||||
dataset.reset_index().to_parquet(dataset_file)
|
||||
# Actually read data into memory
|
||||
if not lazy:
|
||||
ds.load()
|
||||
return ds
|
||||
|
||||
@stopwatch("Loading targets")
|
||||
def _read_target(self) -> gpd.GeoDataFrame:
|
||||
if self.target == "darts_rts":
|
||||
target_store = entropice.paths.get_darts_rts_file(grid=self.grid, level=self.level)
|
||||
elif self.target == "darts_mllabels":
|
||||
target_store = entropice.paths.get_darts_rts_file(grid=self.grid, level=self.level, labels=True)
|
||||
else:
|
||||
raise NotImplementedError(f"Target {self.target} not implemented.")
|
||||
targets = gpd.read_parquet(target_store)
|
||||
|
||||
def main():
|
||||
cyclopts.run(prepare_dataset)
|
||||
# Filter to coverage
|
||||
if self.filter_target:
|
||||
targets = targets[targets[self.filter_target]]
|
||||
# Convert hex cell_id to int
|
||||
if self.grid == "hex":
|
||||
targets["cell_id"] = targets["cell_id"].apply(lambda x: int(x, 16))
|
||||
|
||||
# Add the lat / lon of the cell centers
|
||||
if self.add_lonlat:
|
||||
targets["lon"] = targets.geometry.centroid.x
|
||||
targets["lat"] = targets.geometry.centroid.y
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
return targets
|
||||
|
||||
@stopwatch.f("Prepare ERA5 data", print_kwargs=["temporal"])
def _prep_era5(
    self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
) -> pd.DataFrame:
    """Read an ERA5 member and flatten it into one feature column per variable and time (and aggregation).

    Args:
        targets (gpd.GeoDataFrame): The targets, used to select the cells to read.
        temporal (Literal["yearly", "seasonal", "shoulder"]): The temporal aggregation of the ERA5 member.

    Returns:
        pd.DataFrame: Dataframe indexed by "cell_ids" with columns named
            "era5_<var>_<t>" or "era5_<var>_<t>_<agg>".

    """
    era5 = self._read_member(f"ERA5-{temporal}", targets)

    var_dfs = []
    for var in era5.data_vars:
        df = era5[var].to_dataframe()
        df["t"] = _get_era5_tempus(df, temporal)
        # Check the dimensions of the variable itself, not of the whole dataset:
        # a variable without aggregations has no such index level to pivot on
        if "aggregations" not in era5[var].dims:
            df = (
                df.pivot_table(index="cell_ids", columns="t", values=var)
                .rename(columns=lambda x: f"era5_{var}_{x}")
                .rename_axis(None, axis=1)
            )
        else:
            df = df.pivot_table(index="cell_ids", columns=["t", "aggregations"], values=var)
            df.columns = [f"era5_{var}_{t}_{agg}" for t, agg in df.columns]
        var_dfs.append(df)
    return pd.concat(var_dfs, axis=1)
|
||||
|
||||
@stopwatch("Prepare embeddings data")
def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
    """Read the AlphaEarth member and flatten its "embeddings" variable into feature columns.

    Args:
        targets (gpd.GeoDataFrame): The targets, used to select the cells to read.

    Returns:
        pd.DataFrame: Dataframe indexed by "cell_ids" with columns named "embeddings_<agg>_<band>_<year>".

    """
    member_ds = self._read_member("AlphaEarth", targets)
    long_df = member_ds["embeddings"].to_dataframe(name="value")

    # One column per (year, agg, band) combination
    wide_df = long_df.pivot_table(index="cell_ids", columns=["year", "agg", "band"], values="value")
    wide_df.columns = [f"embeddings_{agg}_{band}_{year}" for (year, agg, band) in wide_df.columns]
    return wide_df
|
||||
|
||||
@stopwatch("Prepare arcticdem data")
def _prep_arcticdem(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
    """Read the ArcticDEM member and flatten it into one feature column per variable and aggregation.

    Args:
        targets (gpd.GeoDataFrame): The targets, used to select the cells to read.

    Returns:
        pd.DataFrame: Dataframe indexed by "cell_ids" with columns named "arcticdem_<var>_<agg>".

    """
    dem_ds = self._read_member("ArcticDEM", targets)
    long_df = dem_ds.to_dataframe()

    # Pivoting all variables at once yields (variable, aggregation) column tuples
    wide_df = long_df.pivot_table(index="cell_ids", columns="aggregations")
    wide_df.columns = [f"arcticdem_{var}_{agg}" for (var, agg) in wide_df.columns]
    return wide_df
|
||||
|
||||
def print_stats(self):
    """Print the number of target samples and, per member, its variables, dimensions and feature count.

    Members are opened with ``lazy=True``. The feature count of a member is the number of
    its data variables multiplied by the sizes of all dimensions except "cell_ids".
    """
    targets = self._read_target()
    print(f"=== Target: {self.target}")
    print(f"\tNumber of target samples: {len(targets)}")

    # Lat and Lon
    total_features = 2 if self.add_lonlat else 0
    for member in self.members:
        ds = self._read_member(member, targets, lazy=True)
        print(f"=== Member: {member}")
        print(f"\tVariables ({len(ds.data_vars)}): {list(ds.data_vars)}")
        print(f"\tDimensions: {dict(ds.sizes)}")
        print(f"\tCoordinates: {list(ds.coords)}")

        member_features = len(ds.data_vars)
        for dim_name, dim_size in ds.sizes.items():
            if dim_name == "cell_ids":
                continue
            member_features *= dim_size
        print(f"\tNumber of features from member: {member_features}")
        total_features += member_features
    print(f"=== Total number of features in dataset: {total_features}")
|
||||
|
||||
def create(self) -> pd.DataFrame:
    """Build the ensemble dataset by joining the features of all members onto the targets.

    Returns:
        pd.DataFrame: The targets indexed by "cell_id", joined with the feature columns of every member.

    Raises:
        NotImplementedError: If a member is not one of the implemented L2 datasets.

    """
    targets = self._read_target()

    frames = []
    for member in self.members:
        if member == "AlphaEarth":
            frame = self._prep_embeddings(targets)
        elif member == "ArcticDEM":
            frame = self._prep_arcticdem(targets)
        elif member.startswith("ERA5"):
            # "ERA5-<temporal>" -> "<temporal>"
            frame = self._prep_era5(targets, member.split("-")[1])
        else:
            raise NotImplementedError(f"Member {member} not implemented.")
        frames.append(frame)

    with stopwatch("Combine datasets"):
        indexed_targets = targets.set_index("cell_id")
        dataset = indexed_targets.join(frames)
    print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
    return dataset
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
|
|||
|
||||
dartsl2_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
|
||||
dartsl2_cov_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
|
||||
darts_ml_training_labels_repo = DARTS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
|
||||
|
||||
|
||||
def _get_gridname(grid: Literal["hex", "healpix"], level: int) -> str:
|
||||
|
|
@ -52,9 +53,12 @@ def get_grid_viz_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
|||
return vizfile
|
||||
|
||||
|
||||
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool = False) -> Path:
    """Get the path of the parquet file with the DARTS data aggregated to a grid.

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.
        labels (bool, optional): Whether to return the file of the ML training labels
            ("<gridname>_darts-mllabels.parquet") instead of "<gridname>_darts.parquet".
            Defaults to False.

    Returns:
        Path: The path of the parquet file inside the DARTS directory.

    """
    gridname = _get_gridname(grid, level)
    suffix = "darts-mllabels" if labels else "darts"
    return DARTS_DIR / f"{gridname}_{suffix}.parquet"
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue