Create Ensemble Datasets

This commit is contained in:
Tobias Hölzer 2025-12-09 17:10:43 +01:00
parent 33c9667383
commit 67030c9f0d
10 changed files with 839 additions and 626 deletions

3
.gitignore vendored
View file

@ -25,3 +25,6 @@ playground.ipynb
# pixi environments
.pixi
*.egg-info
# Disable all notebooks for now
notebooks/*.ipynb

1019
pixi.lock generated

File diff suppressed because it is too large Load diff

View file

@ -45,7 +45,7 @@ dependencies = [
"ultraplot>=1.63.0",
"xanimate",
"xarray>=2025.9.0",
"xdggs>=0.2.1",
"xdggs",
"xvec>=0.5.1",
"zarr[remote]>=3.1.3",
"geocube>=0.7.1,<0.8",
@ -57,12 +57,16 @@ dependencies = [
"xgboost>=3.1.1,<4",
"s3fs>=2025.10.0,<2026",
"xarray-spatial",
"cupy-xarray>=0.1.4,<0.2", "memray>=1.19.1,<2", "xarray-histogram>=0.2.2,<0.3", "antimeridian>=0.4.5,<0.5", "duckdb>=1.4.2,<2",
"cupy-xarray>=0.1.4,<0.2",
"memray>=1.19.1,<2",
"xarray-histogram>=0.2.2,<0.3",
"antimeridian>=0.4.5,<0.5",
"duckdb>=1.4.2,<2",
]
[project.scripts]
create-grid = "entropice.grids:main"
darts = "entropice.darts:main"
darts = "entropice.darts:cli"
alpha-earth = "entropice.alphaearth:main"
era5 = "entropice.era5:cli"
arcticdem = "entropice.arcticdem:cli"
@ -86,6 +90,7 @@ entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "re
entropy = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" }
xanimate = { git = "https://github.com/davbyr/xAnimate" }
xdem = { git = "https://github.com/GlacioHack/xdem" }
xdggs = { git = "https://github.com/relativityhd/xdggs", branch = "feature/make-plotting-useful" }
xarray-spatial = { git = "https://github.com/relativityhd/xarray-spatial" }
cudf-cu12 = { index = "nvidia" }
cuml-cu12 = { index = "nvidia" }
@ -136,3 +141,4 @@ cudnn = ">=9.13.1.26,<10"
cusparselt = ">=0.8.1.1,<0.9"
cuda-version = "12.9.*"
rapids = ">=25.10.0,<26"
healpix-geo = ">=0.0.6"

View file

@ -1,11 +1,11 @@
#! /bin/bash
pixi run darts --grid hex --level 3
pixi run darts --grid hex --level 4
pixi run darts --grid hex --level 5
pixi run darts --grid hex --level 6
pixi run darts --grid healpix --level 6
pixi run darts --grid healpix --level 7
pixi run darts --grid healpix --level 8
pixi run darts --grid healpix --level 9
pixi run darts --grid healpix --level 10
pixi run darts extract_darts_mllabels --grid hex --level 3
pixi run darts extract_darts_mllabels --grid hex --level 4
pixi run darts extract_darts_mllabels --grid hex --level 5
pixi run darts extract_darts_mllabels --grid hex --level 6
pixi run darts extract_darts_mllabels --grid healpix --level 6
pixi run darts extract_darts_mllabels --grid healpix --level 7
pixi run darts extract_darts_mllabels --grid healpix --level 8
pixi run darts extract_darts_mllabels --grid healpix --level 9
pixi run darts extract_darts_mllabels --grid healpix --level 10

View file

@ -1,21 +1,21 @@
#!/bin/bash
pixi run alpha-earth download --grid hex --level 3
pixi run alpha-earth download --grid hex --level 4
pixi run alpha-earth download --grid hex --level 5
pixi run alpha-earth download --grid hex --level 6
pixi run alpha-earth download --grid healpix --level 6
pixi run alpha-earth download --grid healpix --level 7
pixi run alpha-earth download --grid healpix --level 8
pixi run alpha-earth download --grid healpix --level 9
pixi run alpha-earth download --grid healpix --level 10
# pixi run alpha-earth download --grid hex --level 3
# pixi run alpha-earth download --grid hex --level 4
# pixi run alpha-earth download --grid hex --level 5
# pixi run alpha-earth download --grid hex --level 6
# pixi run alpha-earth download --grid healpix --level 6
# pixi run alpha-earth download --grid healpix --level 7
# pixi run alpha-earth download --grid healpix --level 8
# pixi run alpha-earth download --grid healpix --level 9
# pixi run alpha-earth download --grid healpix --level 10
pixi run alpha-earth combine-to-zarr --grid hex --level 3
pixi run alpha-earth combine-to-zarr --grid hex --level 4
pixi run alpha-earth combine-to-zarr --grid hex --level 5
pixi run alpha-earth combine-to-zarr --grid hex --level 6
# pixi run alpha-earth combine-to-zarr --grid hex --level 6
pixi run alpha-earth combine-to-zarr --grid healpix --level 6
pixi run alpha-earth combine-to-zarr --grid healpix --level 7
pixi run alpha-earth combine-to-zarr --grid healpix --level 8
pixi run alpha-earth combine-to-zarr --grid healpix --level 9
pixi run alpha-earth combine-to-zarr --grid healpix --level 10
# pixi run alpha-earth combine-to-zarr --grid healpix --level 10

View file

@ -5,14 +5,18 @@ Date: October 2025
"""
import warnings
from collections.abc import Generator
from typing import Literal
import cudf
import cuml.cluster
import cyclopts
import ee
import geemap
import geopandas as gpd
import numpy as np
import pandas as pd
import sklearn.cluster
import xarray as xr
import xdggs
from rich import pretty, print, traceback
@ -33,6 +37,22 @@ cli = cyclopts.App(name="alpha-earth")
# 7454521782,230147807,10000000.
def _batch_grid(
    grid_gdf: gpd.GeoDataFrame,
    n_partitions: int,
    gpu_threshold: int = 100_000,
) -> Generator[pd.DataFrame]:
    """Partition a grid into spatially compact batches via KMeans on cell centroids.

    Args:
        grid_gdf (gpd.GeoDataFrame): Grid cells to partition.
        n_partitions (int): Number of batches to yield.
        gpu_threshold (int, optional): Centroid count above which the cuML (GPU)
            KMeans implementation is used instead of scikit-learn. Defaults to 100_000.

    Yields:
        gpd.GeoDataFrame: One spatially clustered subset of ``grid_gdf`` per partition.
    """
    centroids = pd.DataFrame({"x": grid_gdf.geometry.centroid.x, "y": grid_gdf.geometry.centroid.y})
    # GPU clustering only pays off for large inputs; fall back to sklearn otherwise.
    if len(centroids) > gpu_threshold:
        print(f"Using cuML KMeans for partitioning {len(centroids)} centroids")
        centroids_cudf = cudf.DataFrame.from_pandas(centroids)
        kmeans = cuml.cluster.KMeans(n_clusters=n_partitions, random_state=42)
        labels = kmeans.fit_predict(centroids_cudf).to_pandas().to_numpy()
    else:
        labels = sklearn.cluster.KMeans(n_clusters=n_partitions, random_state=42).fit_predict(centroids)
    for i in range(n_partitions):
        # Clusters may be unevenly sized, but every cell lands in exactly one batch.
        yield grid_gdf.iloc[np.where(labels == i)[0]]
@cli.command()
def download(grid: Literal["hex", "healpix"], level: int):
"""Extract satellite embeddings from Google Earth Engine and map them to a grid.
@ -77,11 +97,11 @@ def download(grid: Literal["hex", "healpix"], level: int):
return feature.set(mean_dict)
# Process grid in batches of 100
batch_size = 50
batch_size = 25
all_results = []
n_batches = len(grid_gdf) // batch_size
for batch_num, batch_grid in track(
enumerate(np.array_split(grid_gdf, n_batches)),
for batch_grid in track(
_batch_grid(grid_gdf, n_batches),
description="Processing batches...",
total=n_batches,
):
@ -121,6 +141,7 @@ def combine_to_zarr(grid: Literal["hex", "healpix"], level: int):
np.nan,
dims=("year", "cell_ids", "band", "agg"),
coords={"year": years, "cell_ids": cell_ids, "band": bands, "agg": aggs},
name="embeddings",
).astype(np.float32)
# ? These attributes are needed for xdggs
@ -141,8 +162,9 @@ def combine_to_zarr(grid: Literal["hex", "healpix"], level: int):
a = xdggs.decode(a)
a = a.chunk({"year": 1, "cell_ids": 100000, "band": 64, "agg": 12})
zarr_path = get_embeddings_store(grid, level)
a.to_zarr(zarr_path, consolidated=False, mode="w", encoding=codecs.from_ds(a))
a.to_zarr(zarr_path, consolidated=False, mode="w", encoding=codecs.from_da(a))
print(f"Saved combined embeddings to {zarr_path}.")

View file

@ -36,3 +36,24 @@ def from_ds(
if ds[var].dtype == "float64":
encoding[var]["dtype"] = "float32"
return encoding
def from_da(da: xr.DataArray, store_floats_as_float32: bool = True) -> dict:
    """Create compression encoding for zarr DataArray storage.

    Builds a Blosc (zstd, level 5) compression configuration keyed by the
    DataArray's name, optionally downcasting float64 data to float32.

    Args:
        da (xr.DataArray): The xarray DataArray to create encoding for.
        store_floats_as_float32 (bool, optional): Whether to store floating point data as float32.
            Defaults to True.

    Returns:
        dict: Encoding dictionary with compression settings.
    """
    var_encoding: dict = {"compressors": BloscCodec(cname="zstd", clevel=5)}
    # Halve storage for double-precision data unless the caller opts out.
    downcast = store_floats_as_float32 and da.dtype == "float64"
    if downcast:
        var_encoding["dtype"] = "float32"
    return {da.name: var_encoding}

View file

@ -10,17 +10,21 @@ from typing import Literal
import cyclopts
import geopandas as gpd
import pandas as pd
from rich import pretty, print, traceback
from rich.progress import track
from stopuhr import stopwatch
from entropice import grids
from entropice.paths import dartsl2_cov_file, dartsl2_file, get_darts_rts_file
from entropice.paths import darts_ml_training_labels_repo, dartsl2_cov_file, dartsl2_file, get_darts_rts_file
traceback.install()
pretty.install()
cli = cyclopts.App(name="darts-rts")
@cli.command()
def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
"""Extract RTS labels from DARTS dataset.
@ -81,8 +85,61 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
stopwatch.summary()
@cli.command()
def extract_darts_mllabels(grid: Literal["hex", "healpix"], level: int):
"""Extract DARTS ML training labels and image footprints onto a grid.

Reads the TrainingLabel* and ImageFootprints* geopackages from the DARTS ML
training labels repository, dissolves overlapping geometries, intersects them
with the grid, and derives per-cell RTS count, area, coverage and density
columns. The enriched grid is written to a parquet file.

Args:
    grid (Literal["hex", "healpix"]): The grid type to use.
    level (int): The grid level to use.
"""
with stopwatch("Load data"):
grid_gdf = grids.open(grid, level)
# Concatenate all training-label geopackages and reproject to the grid CRS.
darts_mllabels = (
gpd.GeoDataFrame(
pd.concat([gpd.read_file(f) for f in darts_ml_training_labels_repo.glob("**/TrainingLabel*.gpkg")])
)
.reset_index(drop=True)
.to_crs(grid_gdf.crs)
)
# Filter out invalid labels
darts_mllabels = darts_mllabels[darts_mllabels.geometry.is_valid]
# Remove overlapping labels by dissolving
darts_mllabels = darts_mllabels[["geometry"]].dissolve().explode()
# Image footprints describe where labelling coverage exists at all.
darts_cov_mllabels = (
gpd.GeoDataFrame(
pd.concat([gpd.read_file(f) for f in darts_ml_training_labels_repo.glob("**/ImageFootprints*.gpkg")])
)
.reset_index(drop=True)
.to_crs(grid_gdf.crs)
)
darts_cov_mllabels = darts_cov_mllabels[["geometry"]].dissolve().explode()
with stopwatch("Extract RTS labels"):
# Intersect labels and coverage footprints with the grid cells.
grid_mllabels = grid_gdf.overlay(darts_mllabels.to_crs(grid_gdf.crs), how="intersection")
grid_cov_mllabels = grid_gdf.overlay(darts_cov_mllabels.to_crs(grid_gdf.crs), how="intersection")
with stopwatch("Processing RTS"):
# Per-cell label count and total label / coverage areas.
counts = grid_mllabels.groupby("cell_id").size()
grid_gdf["dartsml_rts_count"] = grid_gdf.cell_id.map(counts)
areas = grid_mllabels.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False)
grid_gdf["dartsml_rts_area"] = grid_gdf.cell_id.map(areas)
areas_cov = grid_cov_mllabels.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False)
grid_gdf["dartsml_covered_area"] = grid_gdf.cell_id.map(areas_cov)
# Coverage is the fraction of the cell area observed by any footprint;
# density normalises label area by the observed (not total) area.
grid_gdf["dartsml_coverage"] = grid_gdf["dartsml_covered_area"] / grid_gdf.geometry.area
grid_gdf["dartsml_rts_density"] = grid_gdf["dartsml_rts_area"] / grid_gdf["dartsml_covered_area"]
# Apply corrections to NaNs
covered = ~grid_gdf["dartsml_coverage"].isna()
# Cells with coverage but no labels are true zeros, not missing data.
grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0)
grid_gdf["darts_has_coverage"] = ~grid_gdf["dartsml_coverage"].isna()
grid_gdf["darts_has_rts"] = ~grid_gdf["dartsml_rts_count"].isna()
output_path = get_darts_rts_file(grid, level, labels=True)
grid_gdf.to_parquet(output_path)
print(f"Saved RTS labels to {output_path}")
stopwatch.summary()
def main(): # noqa: D103
cyclopts.run(extract_darts_rts)
cli()
if __name__ == "__main__":

View file

@ -1,8 +1,19 @@
"""Training dataset preparation and model training."""
"""Training dataset preparation and model training.
Naming conventions:
- Ensemble Dataset
- Member-Datasets: selection based on config
- L2-Datasets: ready to use XDGGS Xarray Datasets
- Currently implemented: embeddings (alpha-earth), ERA5-seasonal, ERA5-shoulder, ERA5-yearly, Arcticdem32m
- All L2 Datasets can be multidimensional, but all have at least the cell_ids dimension
- All L2 Datasets have at least one data variable
- Dimensions of L2 Datasets are e.g. time or aggregation
"""
from dataclasses import dataclass, field
from typing import Literal
import cyclopts
import geopandas as gpd
import pandas as pd
import seaborn as sns
@ -11,12 +22,7 @@ from rich import pretty, traceback
from sklearn import set_config
from stopuhr import stopwatch
from entropice.paths import (
get_darts_rts_file,
get_embeddings_store,
get_era5_stores,
get_train_dataset_file,
)
import entropice.paths
traceback.install()
pretty.install()
@ -26,105 +32,182 @@ set_config(array_api_dispatch=True)
sns.set_theme("talk", "whitegrid")
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
seasons = {10: "winter", 4: "summer"}
@stopwatch.f("Prepare ERA5 data", print_kwargs=["temporal"])
def _prep_era5(
rts: gpd.GeoDataFrame,
temporal: Literal["yearly", "seasonal", "shoulder"],
grid: Literal["hex", "healpix"],
level: int,
) -> pd.DataFrame:
era5_df = []
era5_store = get_era5_stores(temporal, grid=grid, level=level)
era5 = xr.open_zarr(era5_store, consolidated=False)
era5 = era5.sel(cell_ids=rts["cell_id"].values)
for var in era5.data_vars:
df = era5[var].drop_vars("spatial_ref").to_dataframe()
if temporal == "yearly":
df["t"] = df.index.get_level_values("time").year
elif temporal == "seasonal":
df["t"] = (
df.index.get_level_values("time")
.month.map(lambda x: seasons.get(x))
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
)
elif temporal == "shoulder":
df["t"] = (
df.index.get_level_values("time")
.month.map(lambda x: shoulder_seasons.get(x))
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
)
df = (
df.pivot_table(index="cell_ids", columns="t", values=var)
.rename(columns=lambda x: f"{var}_{x}")
.rename_axis(None, axis=1)
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
if temporal == "yearly":
return df.index.get_level_values("time").year
elif temporal == "seasonal":
seasons = {10: "winter", 4: "summer"}
return (
df.index.get_level_values("time")
.month.map(lambda x: seasons.get(x))
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
)
elif temporal == "shoulder":
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
return (
df.index.get_level_values("time")
.month.map(lambda x: shoulder_seasons.get(x))
.str.cat(df.index.get_level_values("time").year.astype(str), sep="_")
)
era5_df.append(df)
era5_df = pd.concat(era5_df, axis=1)
era5_df = era5_df.rename(columns={col: f"era5_{col}" for col in era5_df.columns if col != "cell_id"})
return era5_df
@stopwatch("Prepare embeddings data")
def _prep_embeddings(rts: gpd.GeoDataFrame, grid: Literal["hex", "healpix"], level: int) -> pd.DataFrame:
embs_store = get_embeddings_store(grid=grid, level=level)
embeddings = xr.open_zarr(embs_store, consolidated=False).__xarray_dataarray_variable__
embeddings = embeddings.sel(cell=rts["cell_id"].values)
embeddings_df = embeddings.to_dataframe(name="value")
embeddings_df = embeddings_df.pivot_table(index="cell", columns=["year", "agg", "band"], values="value")
embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
embeddings_df = embeddings_df.rename(
columns={col: f"embeddings_{col}" for col in embeddings_df.columns if col != "cell_id"}
)
return embeddings_df
type L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
def prepare_dataset(grid: Literal["hex", "healpix"], level: int, filter_target: bool = False):
"""Prepare training dataset by combining DARTS RTS labels, ERA5 data, and embeddings.
@dataclass
class DatasetEnsemble:
grid: Literal["hex", "healpix"]
level: int
target: Literal["darts_rts", "darts_mllabels"]
members: list[L2Dataset]
dimension_filters: dict[L2Dataset, dict[str, list]] = field(default_factory=dict)
variable_filters: dict[L2Dataset, list[str]] = field(default_factory=dict)
filter_target: str | Literal[False] = False
add_lonlat: bool = True
Args:
grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use.
def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
if member == "AlphaEarth":
store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level)
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
store = entropice.paths.get_era5_stores(member.split("-")[1], grid=self.grid, level=self.level)
elif member == "ArcticDEM":
store = entropice.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
else:
raise NotImplementedError(f"Member {member} not implemented.")
"""
rts = gpd.read_parquet(get_darts_rts_file(grid=grid, level=level))
# Filter to coverage
if filter_target:
rts = rts[rts["darts_has_coverage"]]
# Convert hex cell_id to int
if grid == "hex":
rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
ds = xr.open_zarr(store, consolidated=False)
# Add the lat / lon of the cell centers
rts["lon"] = rts.geometry.centroid.x
rts["lat"] = rts.geometry.centroid.y
# Apply variable filters
if member in self.variable_filters:
assert isinstance(self.variable_filters[member], list) and len(self.variable_filters[member]) >= 1, (
f"Invalid variable filter for {member=}: {self.variable_filters[member]}"
" Variable filter values must be a list with one or more entries."
)
ds = ds[self.variable_filters[member]]
# Get era5 data
era5_yearly = _prep_era5(rts, "yearly", grid, level)
era5_seasonal = _prep_era5(rts, "seasonal", grid, level)
era5_shoulder = _prep_era5(rts, "shoulder", grid, level)
# Apply dimension filters
if member in self.dimension_filters:
for dim, values in self.dimension_filters[member].items():
assert isinstance(values, list) and len(values) >= 1, (
f"Invalid dimension filter for {dim=}: {values}"
" Dimension filter values must be a list with one or more entries."
)
ds = ds.sel({dim: values})
# Get embeddings data
embeddings = _prep_embeddings(rts, grid, level)
# Delete all coordinates which are not in the dimension
for coord in ds.coords:
if coord not in ds.dims:
ds = ds.drop_vars(coord)
# Combine datasets by cell id / cell
with stopwatch("Combine datasets"):
dataset = rts.set_index("cell_id").join(era5_yearly).join(era5_seasonal).join(era5_shoulder).join(embeddings)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
# Only load target cell ids
intersecting_cell_ids = set(ds["cell_ids"].values).intersection(set(targets["cell_id"].values))
ds = ds.sel(cell_ids=list(intersecting_cell_ids))
dataset_file = get_train_dataset_file(grid=grid, level=level)
dataset.reset_index().to_parquet(dataset_file)
# Actually read data into memory
if not lazy:
ds.load()
return ds
@stopwatch("Loading targets")
def _read_target(self) -> gpd.GeoDataFrame:
if self.target == "darts_rts":
target_store = entropice.paths.get_darts_rts_file(grid=self.grid, level=self.level)
elif self.target == "darts_mllabels":
target_store = entropice.paths.get_darts_rts_file(grid=self.grid, level=self.level, labels=True)
else:
raise NotImplementedError(f"Target {self.target} not implemented.")
targets = gpd.read_parquet(target_store)
def main():
cyclopts.run(prepare_dataset)
# Filter to coverage
if self.filter_target:
targets = targets[targets[self.filter_target]]
# Convert hex cell_id to int
if self.grid == "hex":
targets["cell_id"] = targets["cell_id"].apply(lambda x: int(x, 16))
# Add the lat / lon of the cell centers
if self.add_lonlat:
targets["lon"] = targets.geometry.centroid.x
targets["lat"] = targets.geometry.centroid.y
if __name__ == "__main__":
main()
return targets
@stopwatch.f("Prepare ERA5 data", print_kwargs=["temporal"])
def _prep_era5(
self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
) -> pd.DataFrame:
era5_df = []
era5 = self._read_member(f"ERA5-{temporal}", targets)
for var in era5.data_vars:
df = era5[var].to_dataframe()
df["t"] = _get_era5_tempus(df, temporal)
# If aggregations is not in dims, we can pivot directly
if "aggregations" not in era5.dims:
df = (
df.pivot_table(index="cell_ids", columns="t", values=var)
.rename(columns=lambda x: f"era5_{var}_{x}")
.rename_axis(None, axis=1)
)
else:
df = df.pivot_table(index="cell_ids", columns=["t", "aggregations"], values=var)
df.columns = [f"era5_{var}_{t}_{agg}" for t, agg in df.columns]
era5_df.append(df)
era5_df = pd.concat(era5_df, axis=1)
return era5_df
@stopwatch("Prepare embeddings data")
def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
embeddings = self._read_member("AlphaEarth", targets)["embeddings"]
embeddings_df = embeddings.to_dataframe(name="value")
embeddings_df = embeddings_df.pivot_table(index="cell_ids", columns=["year", "agg", "band"], values="value")
embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
return embeddings_df
@stopwatch("Prepare arcticdem data")
def _prep_arcticdem(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
arcticdem = self._read_member("ArcticDEM", targets)
arcticdem_df = arcticdem.to_dataframe().pivot_table(index="cell_ids", columns="aggregations")
arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns]
return arcticdem_df
def print_stats(self):
targets = self._read_target()
print(f"=== Target: {self.target}")
print(f"\tNumber of target samples: {len(targets)}")
n_cols = 2 if self.add_lonlat else 0 # Lat and Lon
for member in self.members:
ds = self._read_member(member, targets, lazy=True)
print(f"=== Member: {member}")
print(f"\tVariables ({len(ds.data_vars)}): {list(ds.data_vars)}")
print(f"\tDimensions: {dict(ds.sizes)}")
print(f"\tCoordinates: {list(ds.coords)}")
n_cols_member = len(ds.data_vars)
for dim in ds.sizes:
if dim != "cell_ids":
n_cols_member *= ds.sizes[dim]
print(f"\tNumber of features from member: {n_cols_member}")
n_cols += n_cols_member
print(f"=== Total number of features in dataset: {n_cols}")
def create(self) -> pd.DataFrame:
targets = self._read_target()
member_dfs = []
for member in self.members:
if member.startswith("ERA5"):
member_dfs.append(self._prep_era5(targets, member.split("-")[1]))
elif member == "AlphaEarth":
member_dfs.append(self._prep_embeddings(targets))
elif member == "ArcticDEM":
member_dfs.append(self._prep_arcticdem(targets))
else:
raise NotImplementedError(f"Member {member} not implemented.")
with stopwatch("Combine datasets"):
dataset = targets.set_index("cell_id").join(member_dfs)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
return dataset

View file

@ -34,6 +34,7 @@ watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
dartsl2_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
dartsl2_cov_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
darts_ml_training_labels_repo = DARTS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
def _get_gridname(grid: Literal["hex", "healpix"], level: int) -> str:
@ -52,9 +53,12 @@ def get_grid_viz_file(grid: Literal["hex", "healpix"], level: int) -> Path:
return vizfile
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int) -> Path:
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool = False) -> Path:
    """Return the parquet file path for DARTS RTS data on the given grid.

    Args:
        grid (Literal["hex", "healpix"]): The grid type.
        level (int): The grid level.
        labels (bool, optional): If True, return the ML-training-labels variant
            instead of the default DARTS product. Defaults to False.

    Returns:
        Path: Location of the requested parquet file.
    """
    gridname = _get_gridname(grid, level)
    # Single-assignment form; the previous version first assigned the default
    # path and then overwrote it in the labels branch.
    variant = "darts-mllabels" if labels else "darts"
    return DARTS_DIR / f"{gridname}_{variant}.parquet"