Refactor dataset ensemble to allow different temporal modes
This commit is contained in:
parent
1495f71ac9
commit
231caa62e7
6 changed files with 708 additions and 505 deletions
|
|
@ -1,15 +1,27 @@
|
||||||
#! /bin/bash
|
#! /bin/bash
|
||||||
|
|
||||||
# pixi shell
|
# pixi shell
|
||||||
darts extract-darts-rts --grid hex --level 3
|
darts extract-darts-v1 --grid hex --level 3
|
||||||
darts extract-darts-rts --grid hex --level 4
|
darts extract-darts-v1 --grid hex --level 4
|
||||||
darts extract-darts-rts --grid hex --level 5
|
darts extract-darts-v1 --grid hex --level 5
|
||||||
darts extract-darts-rts --grid hex --level 6
|
darts extract-darts-v1 --grid hex --level 6
|
||||||
darts extract-darts-rts --grid healpix --level 6
|
darts extract-darts-v1 --grid healpix --level 6
|
||||||
darts extract-darts-rts --grid healpix --level 7
|
darts extract-darts-v1 --grid healpix --level 7
|
||||||
darts extract-darts-rts --grid healpix --level 8
|
darts extract-darts-v1 --grid healpix --level 8
|
||||||
darts extract-darts-rts --grid healpix --level 9
|
darts extract-darts-v1 --grid healpix --level 9
|
||||||
darts extract-darts-rts --grid healpix --level 10
|
darts extract-darts-v1 --grid healpix --level 10
|
||||||
|
|
||||||
|
|
||||||
|
darts extract-darts-v1-aggregated --grid hex --level 3
|
||||||
|
darts extract-darts-v1-aggregated --grid hex --level 4
|
||||||
|
darts extract-darts-v1-aggregated --grid hex --level 5
|
||||||
|
darts extract-darts-v1-aggregated --grid hex --level 6
|
||||||
|
darts extract-darts-v1-aggregated --grid healpix --level 6
|
||||||
|
darts extract-darts-v1-aggregated --grid healpix --level 7
|
||||||
|
darts extract-darts-v1-aggregated --grid healpix --level 8
|
||||||
|
darts extract-darts-v1-aggregated --grid healpix --level 9
|
||||||
|
darts extract-darts-v1-aggregated --grid healpix --level 10
|
||||||
|
|
||||||
|
|
||||||
darts extract-darts-mllabels --grid hex --level 3
|
darts extract-darts-mllabels --grid hex --level 3
|
||||||
darts extract-darts-mllabels --grid hex --level 4
|
darts extract-darts-mllabels --grid hex --level 4
|
||||||
|
|
|
||||||
|
|
@ -9,28 +9,140 @@ Date: October 2025
|
||||||
import cyclopts
|
import cyclopts
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from rich import pretty, print, traceback
|
import xarray as xr
|
||||||
from rich.progress import track
|
import xdggs
|
||||||
|
from rich import pretty, traceback
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
from entropice.spatial import grids
|
from entropice.spatial import grids
|
||||||
from entropice.utils.paths import (
|
from entropice.utils.paths import (
|
||||||
darts_ml_training_labels_repo,
|
DARTS_MLLABELS_DIR,
|
||||||
dartsl2_cov_file,
|
DARTS_V1_DIR,
|
||||||
dartsl2_file,
|
get_darts_file,
|
||||||
get_darts_rts_file,
|
|
||||||
)
|
)
|
||||||
from entropice.utils.types import Grid
|
from entropice.utils.types import Grid
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
||||||
|
darts_v1_l2_file = DARTS_V1_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
|
||||||
|
darts_v1_l2_cov_file = DARTS_V1_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
|
||||||
|
darts_ml_training_labels_repo = DARTS_MLLABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
|
||||||
|
|
||||||
cli = cyclopts.App(name="darts-rts")
|
cli = cyclopts.App(name="darts-rts")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_grid(grid: Grid, level: int) -> tuple[gpd.GeoDataFrame, xr.DataArray]:
|
||||||
|
grid_gdf = grids.open(grid, level)
|
||||||
|
if grid == "hex":
|
||||||
|
grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda x: int(x, 16))
|
||||||
|
cell_areas = xr.DataArray(
|
||||||
|
grid_gdf["cell_area"].to_numpy(), dims=["cell_id"], coords={"cell_id": grid_gdf["cell_id"].to_numpy()}
|
||||||
|
)
|
||||||
|
# We only want to use the geometry and the cell_id
|
||||||
|
grid_gdf = grid_gdf[["cell_id", "geometry"]]
|
||||||
|
return grid_gdf, cell_areas
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch("Convert xdgss")
|
||||||
|
def _convert_xdggs(darts: xr.Dataset, grid: Grid, level: int) -> xr.Dataset:
|
||||||
|
darts = darts.rename({"cell_id": "cell_ids"})
|
||||||
|
darts["cell_ids"] = darts["cell_ids"].astype("uint64")
|
||||||
|
gridinfo = {
|
||||||
|
"grid_name": "h3" if grid == "hex" else grid,
|
||||||
|
"level": level,
|
||||||
|
}
|
||||||
|
if grid == "healpix":
|
||||||
|
gridinfo["indexing_scheme"] = "nested"
|
||||||
|
darts.cell_ids.attrs = gridinfo
|
||||||
|
darts = xdggs.decode(darts)
|
||||||
|
return darts
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch("Process RTS grid")
|
||||||
|
def _process_rts_grid(
|
||||||
|
rts: gpd.GeoDataFrame,
|
||||||
|
cov: gpd.GeoDataFrame,
|
||||||
|
cell_areas: xr.DataArray,
|
||||||
|
) -> xr.Dataset:
|
||||||
|
assert "cell_id" in rts.columns, "RTS data must contain cell_id column."
|
||||||
|
assert "cell_id" in cov.columns, "Coverage data must contain cell_id column."
|
||||||
|
|
||||||
|
cov["area_km2"] = cov.geometry.area / 1e6
|
||||||
|
covered_area: pd.Series = cov.pivot_table(
|
||||||
|
index="cell_id",
|
||||||
|
values="area_km2",
|
||||||
|
aggfunc="sum",
|
||||||
|
fill_value=0,
|
||||||
|
)["area_km2"]
|
||||||
|
covered_area: xr.DataArray = xr.DataArray.from_series(covered_area)
|
||||||
|
coverage = covered_area / cell_areas.sel(cell_id=covered_area.cell_id)
|
||||||
|
rts["area_km2"] = rts.geometry.area / 1e6
|
||||||
|
rts = rts.pivot_table(
|
||||||
|
index="cell_id",
|
||||||
|
values=["geometry", "area_km2"],
|
||||||
|
fill_value=0,
|
||||||
|
aggfunc={"geometry": "count", "area_km2": "sum"}, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
rts: xr.Dataset = xr.Dataset.from_dataframe(rts)
|
||||||
|
darts = xr.merge(
|
||||||
|
[rts, covered_area.rename("covered_area_km2"), coverage.rename("coverage")], join="outer", fill_value=0
|
||||||
|
)
|
||||||
|
darts["density"] = (darts["area_km2"] / darts["covered_area_km2"]).fillna(0)
|
||||||
|
return darts
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch("Process RTS yearly grid")
|
||||||
|
def _process_rts_yearly_grid(
|
||||||
|
rts: gpd.GeoDataFrame,
|
||||||
|
cov: gpd.GeoDataFrame,
|
||||||
|
cell_areas: xr.DataArray,
|
||||||
|
) -> xr.Dataset:
|
||||||
|
assert "cell_id" in rts.columns, "RTS data must contain cell_id column."
|
||||||
|
assert "cell_id" in cov.columns, "Coverage data must contain cell_id column."
|
||||||
|
|
||||||
|
cov["area_km2"] = cov.geometry.area / 1e6
|
||||||
|
covered_area: pd.Series = cov.pivot_table( # type: ignore[var-annotated] # noqa: PD010
|
||||||
|
index="cell_id",
|
||||||
|
columns="year",
|
||||||
|
values="area_km2",
|
||||||
|
aggfunc="sum",
|
||||||
|
fill_value=0,
|
||||||
|
).unstack()
|
||||||
|
covered_area: xr.DataArray = xr.DataArray.from_series(covered_area)
|
||||||
|
coverage = covered_area / cell_areas.sel(cell_id=covered_area.cell_id)
|
||||||
|
rts: pd.DataFrame = (
|
||||||
|
rts.pivot_table( # type: ignore[var-annotated] # noqa: PD013
|
||||||
|
index="cell_id",
|
||||||
|
columns="year",
|
||||||
|
values=["geometry", "area_km2"],
|
||||||
|
fill_value=0,
|
||||||
|
aggfunc={"geometry": "count", "area_km2": "sum"}, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
.stack(future_stack=True)
|
||||||
|
.rename(columns={"geometry": "count"})
|
||||||
|
)
|
||||||
|
rts: xr.Dataset = xr.Dataset.from_dataframe(rts)
|
||||||
|
darts = xr.merge(
|
||||||
|
[rts, covered_area.rename("covered_area_km2"), coverage.rename("coverage")], join="outer", fill_value=0
|
||||||
|
)
|
||||||
|
darts["density"] = (darts["area_km2"] / darts["covered_area_km2"]).fillna(0)
|
||||||
|
return darts
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def extract_darts_rts(grid: Grid, level: int):
|
def extract_darts_v1(grid: Grid, level: int):
|
||||||
"""Extract RTS labels from DARTS dataset.
|
"""Extract RTS labels from DARTS-v1 Level-2 dataset.
|
||||||
|
|
||||||
|
Creates a Darts-v1 xarray Dataset on the specified grid and level.
|
||||||
|
The Dataset contains the following variables:
|
||||||
|
- count: Number of RTS in the cell
|
||||||
|
- area_km2: Total area of RTS in the cell (in km^2)
|
||||||
|
- covered_area_km2: Area of the cell covered by DARTS (in km^2)
|
||||||
|
- coverage: Fraction of the cell covered by DARTS
|
||||||
|
- density: Density of RTS area per covered area (area_km2 / covered_area_km2)
|
||||||
|
Since the DARTS-v1 Level-2 dataset contains yearly data, all variables are indexed by year as well.
|
||||||
|
Thus each variable has dimensions (cell_ids, year).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
grid (Grid): The grid type to use.
|
grid (Grid): The grid type to use.
|
||||||
|
|
@ -38,66 +150,81 @@ def extract_darts_rts(grid: Grid, level: int):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
with stopwatch("Load data"):
|
with stopwatch("Load data"):
|
||||||
darts_l2 = gpd.read_parquet(dartsl2_file)
|
darts_l2 = gpd.read_parquet(darts_v1_l2_file)
|
||||||
darts_cov_l2 = gpd.read_parquet(dartsl2_cov_file)
|
darts_cov_l2 = gpd.read_parquet(darts_v1_l2_cov_file)
|
||||||
grid_gdf = grids.open(grid, level)
|
grid_gdf, cell_areas = _load_grid(grid, level)
|
||||||
|
|
||||||
with stopwatch("Extract RTS labels"):
|
with stopwatch("Assign RTS to grid"):
|
||||||
grid_l2 = grid_gdf.overlay(darts_l2.to_crs(grid_gdf.crs), how="intersection")
|
grid_l2 = grid_gdf.overlay(darts_l2.to_crs(grid_gdf.crs), how="intersection")
|
||||||
grid_cov_l2 = grid_gdf.overlay(darts_cov_l2.to_crs(grid_gdf.crs), how="intersection")
|
grid_cov_l2 = grid_gdf.overlay(darts_cov_l2.to_crs(grid_gdf.crs), how="intersection")
|
||||||
|
|
||||||
years = list(grid_cov_l2["year"].unique())
|
darts = _process_rts_yearly_grid(grid_l2, grid_cov_l2, cell_areas)
|
||||||
for year in track(years, total=len(years), description="Processing years..."):
|
|
||||||
with stopwatch("Processing RTS", log=False):
|
|
||||||
subset = grid_l2[grid_l2["year"] == year]
|
|
||||||
subset_cov = grid_cov_l2[grid_cov_l2["year"] == year]
|
|
||||||
|
|
||||||
counts = subset.groupby("cell_id").size()
|
darts = _convert_xdggs(darts, grid, level)
|
||||||
grid_gdf[f"darts_{year}_rts_count"] = grid_gdf.cell_id.map(counts)
|
output_path = get_darts_file(grid, level, version="v1")
|
||||||
|
with stopwatch(f"Writing Darts v1 to {output_path}"):
|
||||||
|
darts.to_zarr(output_path, consolidated=False, mode="w")
|
||||||
|
|
||||||
areas = subset.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False)
|
|
||||||
grid_gdf[f"darts_{year}_rts_area"] = grid_gdf.cell_id.map(areas)
|
|
||||||
|
|
||||||
areas_cov = subset_cov.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False)
|
@cli.command()
|
||||||
grid_gdf[f"darts_{year}_covered_area"] = grid_gdf.cell_id.map(areas_cov)
|
def extract_darts_v1_aggregated(grid: Grid, level: int):
|
||||||
grid_gdf[f"darts_{year}_coverage"] = grid_gdf[f"darts_{year}_covered_area"] / grid_gdf.geometry.area
|
"""Extract RTS labels from DARTS-v1 Level-3 dataset.
|
||||||
|
|
||||||
grid_gdf[f"darts_{year}_rts_density"] = (
|
Creates a Darts-v1 xarray Dataset on the specified grid and level.
|
||||||
grid_gdf[f"darts_{year}_rts_area"] / grid_gdf[f"darts_{year}_covered_area"]
|
The Dataset contains the following variables:
|
||||||
)
|
- count: Number of RTS in the cell
|
||||||
|
- area_km2: Total area of RTS in the cell (in km^2)
|
||||||
|
- covered_area_km2: Area of the cell covered by DARTS (in km^2)
|
||||||
|
- coverage: Fraction of the cell covered by DARTS
|
||||||
|
- density: Density of RTS area per covered area (area_km2 / covered_area_km2)
|
||||||
|
Since the DARTS-v1 Level-2 dataset contains yearly data, the data is dissolved then exploded to obtain Level-3 data.
|
||||||
|
Thus each variable has only the dimension (cell_ids).
|
||||||
|
|
||||||
# Apply corrections to NaNs
|
Args:
|
||||||
covered = ~grid_gdf[f"darts_{year}_coverage"].isna()
|
grid (Grid): The grid type to use.
|
||||||
grid_gdf.loc[covered, f"darts_{year}_rts_count"] = grid_gdf.loc[covered, f"darts_{year}_rts_count"].fillna(
|
level (int): The grid level to use.
|
||||||
0.0
|
|
||||||
)
|
|
||||||
grid_gdf.loc[covered, f"darts_{year}_rts_density"] = grid_gdf.loc[
|
|
||||||
covered, f"darts_{year}_rts_density"
|
|
||||||
].fillna(0.0)
|
|
||||||
grid_gdf[f"darts_{year}_has_coverage"] = covered
|
|
||||||
grid_gdf[f"darts_{year}_has_rts"] = grid_gdf[f"darts_{year}_rts_count"] > 0
|
|
||||||
|
|
||||||
grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1)
|
"""
|
||||||
grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1)
|
with stopwatch("Load data"):
|
||||||
|
darts_l2 = gpd.read_parquet(darts_v1_l2_file)
|
||||||
|
darts_cov_l2 = gpd.read_parquet(darts_v1_l2_cov_file)
|
||||||
|
grid_gdf, cell_areas = _load_grid(grid, level)
|
||||||
|
# Remove overlapping labels by dissolving
|
||||||
|
darts_l2 = darts_l2[["geometry"]].dissolve().explode()
|
||||||
|
darts_cov_l2 = darts_cov_l2[["geometry"]].dissolve().explode()
|
||||||
|
|
||||||
darts_counts_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_count")]
|
with stopwatch("Extract RTS labels"):
|
||||||
darts_counts = grid_gdf[darts_counts_columns]
|
grid_l3 = grid_gdf.overlay(darts_l2.to_crs(grid_gdf.crs), how="intersection")
|
||||||
grid_gdf["darts_rts_count"] = darts_counts.dropna(axis=0, how="all").sum(axis=1)
|
grid_cov_l3 = grid_gdf.overlay(darts_cov_l2.to_crs(grid_gdf.crs), how="intersection")
|
||||||
|
|
||||||
darts_density_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_density")]
|
darts = _process_rts_grid(grid_l3, grid_cov_l3, cell_areas)
|
||||||
darts_density = grid_gdf[darts_density_columns]
|
darts = _convert_xdggs(darts, grid, level)
|
||||||
grid_gdf["darts_rts_density"] = darts_density.dropna(axis=0, how="all").max(axis=1)
|
output_path = get_darts_file(grid, level, version="v1-l3")
|
||||||
|
with stopwatch(f"Writing Darts v1 l3 to {output_path}"):
|
||||||
output_path = get_darts_rts_file(grid, level)
|
darts.to_zarr(output_path, consolidated=False, mode="w")
|
||||||
grid_gdf.to_parquet(output_path)
|
|
||||||
print(f"Saved RTS labels to {output_path}")
|
|
||||||
stopwatch.summary()
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def extract_darts_mllabels(grid: Grid, level: int):
|
def extract_darts_mllabels(grid: Grid, level: int):
|
||||||
|
"""Extract RTS labels from the DARTS-mllabels dataset.
|
||||||
|
|
||||||
|
Creates a Darts-mllabels xarray Dataset on the specified grid and level.
|
||||||
|
The Dataset contains the following variables:
|
||||||
|
- count: Number of RTS in the cell
|
||||||
|
- area_km2: Total area of RTS in the cell (in km^2)
|
||||||
|
- covered_area_km2: Area of the cell covered by DARTS (in km^2)
|
||||||
|
- coverage: Fraction of the cell covered by DARTS
|
||||||
|
- density: Density of RTS area per covered area (area_km2 / covered_area_km2)
|
||||||
|
Since the DARTS-mllabels dataset contains not much data, all variables are aggregated over the entire time period.
|
||||||
|
Thus each variable has only the dimension (cell_ids).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid (Grid): The grid type to use.
|
||||||
|
level (int): The grid level to use.
|
||||||
|
|
||||||
|
"""
|
||||||
with stopwatch("Load data"):
|
with stopwatch("Load data"):
|
||||||
grid_gdf = grids.open(grid, level)
|
grid_gdf, cell_areas = _load_grid(grid, level)
|
||||||
darts_mllabels = (
|
darts_mllabels = (
|
||||||
gpd.GeoDataFrame(
|
gpd.GeoDataFrame(
|
||||||
pd.concat([gpd.read_file(f) for f in darts_ml_training_labels_repo.glob("**/TrainingLabel*.gpkg")])
|
pd.concat([gpd.read_file(f) for f in darts_ml_training_labels_repo.glob("**/TrainingLabel*.gpkg")])
|
||||||
|
|
@ -118,34 +245,16 @@ def extract_darts_mllabels(grid: Grid, level: int):
|
||||||
.to_crs(grid_gdf.crs)
|
.to_crs(grid_gdf.crs)
|
||||||
)
|
)
|
||||||
darts_cov_mllabels = darts_cov_mllabels[["geometry"]].dissolve().explode()
|
darts_cov_mllabels = darts_cov_mllabels[["geometry"]].dissolve().explode()
|
||||||
|
|
||||||
with stopwatch("Extract RTS labels"):
|
with stopwatch("Extract RTS labels"):
|
||||||
grid_mllabels = grid_gdf.overlay(darts_mllabels.to_crs(grid_gdf.crs), how="intersection")
|
grid_mllabels = grid_gdf.overlay(darts_mllabels.to_crs(grid_gdf.crs), how="intersection")
|
||||||
grid_cov_mllabels = grid_gdf.overlay(darts_cov_mllabels.to_crs(grid_gdf.crs), how="intersection")
|
grid_cov_mllabels = grid_gdf.overlay(darts_cov_mllabels.to_crs(grid_gdf.crs), how="intersection")
|
||||||
|
|
||||||
with stopwatch("Processing RTS"):
|
darts = _process_rts_grid(grid_mllabels, grid_cov_mllabels, cell_areas)
|
||||||
counts = grid_mllabels.groupby("cell_id").size()
|
darts = _convert_xdggs(darts, grid, level)
|
||||||
grid_gdf["dartsml_rts_count"] = grid_gdf.cell_id.map(counts)
|
output_path = get_darts_file(grid, level, version="mllabels")
|
||||||
|
with stopwatch(f"Writing Darts v1 to {output_path}"):
|
||||||
areas = grid_mllabels.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False)
|
darts.to_zarr(output_path, consolidated=False, mode="w")
|
||||||
grid_gdf["dartsml_rts_area"] = grid_gdf.cell_id.map(areas)
|
|
||||||
|
|
||||||
areas_cov = grid_cov_mllabels.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False)
|
|
||||||
grid_gdf["dartsml_covered_area"] = grid_gdf.cell_id.map(areas_cov)
|
|
||||||
grid_gdf["dartsml_coverage"] = grid_gdf["dartsml_covered_area"] / grid_gdf.geometry.area
|
|
||||||
grid_gdf["dartsml_rts_density"] = grid_gdf["dartsml_rts_area"] / grid_gdf["dartsml_covered_area"]
|
|
||||||
|
|
||||||
# Apply corrections to NaNs
|
|
||||||
covered = ~grid_gdf["dartsml_coverage"].isna()
|
|
||||||
grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0)
|
|
||||||
grid_gdf.loc[covered, "dartsml_rts_density"] = grid_gdf.loc[covered, "dartsml_rts_density"].fillna(0.0)
|
|
||||||
|
|
||||||
grid_gdf["dartsml_has_coverage"] = covered
|
|
||||||
grid_gdf["dartsml_has_rts"] = grid_gdf["dartsml_rts_count"] > 0
|
|
||||||
|
|
||||||
output_path = get_darts_rts_file(grid, level, labels=True)
|
|
||||||
grid_gdf.to_parquet(output_path)
|
|
||||||
print(f"Saved RTS labels to {output_path}")
|
|
||||||
stopwatch.summary()
|
|
||||||
|
|
||||||
|
|
||||||
def main(): # noqa: D103
|
def main(): # noqa: D103
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
# ruff: noqa: N806
|
# ruff: noqa: N806, D105
|
||||||
"""Training dataset preparation and model training.
|
"""Training dataset preparation and model training.
|
||||||
|
|
||||||
Naming conventions:
|
Naming conventions:
|
||||||
|
|
@ -16,9 +16,9 @@ import hashlib
|
||||||
import json
|
import json
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from functools import cached_property
|
from functools import cache, cached_property
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from typing import Literal, TypedDict
|
from typing import Literal, cast
|
||||||
|
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
import cyclopts
|
import cyclopts
|
||||||
|
|
@ -33,8 +33,9 @@ from sklearn import set_config
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
|
import entropice.spatial.grids
|
||||||
import entropice.utils.paths
|
import entropice.utils.paths
|
||||||
from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task
|
from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task, TemporalMode
|
||||||
|
|
||||||
traceback.install()
|
traceback.install()
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
@ -44,42 +45,53 @@ set_config(array_api_dispatch=True)
|
||||||
sns.set_theme("talk", "whitegrid")
|
sns.set_theme("talk", "whitegrid")
|
||||||
|
|
||||||
|
|
||||||
covcol: dict[TargetDataset, str] = {
|
def _unstack_era5_time(era5: xr.Dataset, aggregation: Literal["yearly", "seasonal", "shoulder"]) -> xr.Dataset:
|
||||||
"darts_rts": "darts_has_coverage",
|
# In the yearly case, no unstacking is necessary, we can just rename the time dimension to year and change the coord
|
||||||
"darts_mllabels": "dartsml_has_coverage",
|
if aggregation == "yearly":
|
||||||
}
|
era5 = era5.rename({"time": "year"})
|
||||||
|
era5.coords["year"] = era5["year"].dt.year
|
||||||
|
return era5
|
||||||
|
|
||||||
taskcol: dict[Task, dict[TargetDataset, str]] = {
|
# Make the time index a MultiIndex of year and month
|
||||||
"binary": {
|
era5.coords["year"] = era5.time.dt.year
|
||||||
"darts_rts": "darts_has_rts",
|
era5.coords["month"] = era5.time.dt.month
|
||||||
"darts_mllabels": "dartsml_has_rts",
|
era5["time"] = pd.MultiIndex.from_arrays(
|
||||||
},
|
[
|
||||||
"count": {
|
era5.time.dt.year.values, # noqa: PD011
|
||||||
"darts_rts": "darts_rts_count",
|
era5.time.dt.month.values, # noqa: PD011
|
||||||
"darts_mllabels": "dartsml_rts_count",
|
],
|
||||||
},
|
names=("year", "month"),
|
||||||
"density": {
|
)
|
||||||
"darts_rts": "darts_rts_density",
|
era5 = era5.unstack("time") # noqa: PD010
|
||||||
"darts_mllabels": "dartsml_rts_density",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
|
|
||||||
time_index = pd.DatetimeIndex(df.index.get_level_values("time"))
|
|
||||||
if temporal == "yearly":
|
|
||||||
return time_index.year
|
|
||||||
elif temporal == "seasonal":
|
|
||||||
seasons = {10: "winter", 4: "summer"}
|
seasons = {10: "winter", 4: "summer"}
|
||||||
return time_index.month.map(lambda x: seasons.get(x)).str.cat(time_index.year.astype(str), sep="_")
|
|
||||||
elif temporal == "shoulder":
|
|
||||||
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
|
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
|
||||||
return time_index.month.map(lambda x: shoulder_seasons.get(x)).str.cat(time_index.year.astype(str), sep="_")
|
month_map = seasons if aggregation == "seasonal" else shoulder_seasons
|
||||||
|
era5.coords["month"] = era5["month"].to_series().map(month_map)
|
||||||
|
return era5
|
||||||
|
|
||||||
|
|
||||||
|
def _collapse_to_dataframe(ds: xr.Dataset | xr.DataArray) -> pd.DataFrame:
|
||||||
|
collapsed = ds.to_dataframe()
|
||||||
|
# Make a dummy row to avoid empty dataframe issues
|
||||||
|
use_dummy = collapsed.shape[0] == 0
|
||||||
|
if use_dummy:
|
||||||
|
collapsed.loc[tuple(range(len(collapsed.index.names)))] = np.nan
|
||||||
|
pivcols = set(collapsed.index.names) - {"cell_ids"}
|
||||||
|
collapsed = collapsed.pivot_table(index="cell_ids", columns=pivcols)
|
||||||
|
collapsed.columns = ["_".join(v) for v in collapsed.columns]
|
||||||
|
if use_dummy:
|
||||||
|
collapsed = collapsed.dropna(how="all")
|
||||||
|
return collapsed
|
||||||
|
|
||||||
|
|
||||||
|
def _cell_ids_hash(cell_ids: pd.Series) -> str:
|
||||||
|
sorted_ids = np.sort(cell_ids.to_numpy())
|
||||||
|
return hashlib.blake2b(sorted_ids.tobytes(), digest_size=8).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def bin_values(
|
def bin_values(
|
||||||
values: pd.Series,
|
values: pd.Series,
|
||||||
task: Literal["count", "density"],
|
task: Literal["count_regimes", "density_regimes"],
|
||||||
none_val: float = 0,
|
none_val: float = 0,
|
||||||
) -> pd.Series:
|
) -> pd.Series:
|
||||||
"""Bin values into predefined intervals for different tasks.
|
"""Bin values into predefined intervals for different tasks.
|
||||||
|
|
@ -89,7 +101,7 @@ def bin_values(
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
values (pd.Series): Pandas Series of numerical values to bin.
|
values (pd.Series): Pandas Series of numerical values to bin.
|
||||||
task (Literal["count", "density"]): Task type - 'count' or 'density'.
|
task (Literal["count_regimes", "density_regimes"]): Task type - 'count_regimes' or 'density_regimes'.
|
||||||
none_val (float, optional): Value representing 'none' or 'empty' (e.g., 0 for count). Defaults to 0.
|
none_val (float, optional): Value representing 'none' or 'empty' (e.g., 0 for count). Defaults to 0.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -100,15 +112,8 @@ def bin_values(
|
||||||
|
|
||||||
"""
|
"""
|
||||||
labels_dict = {
|
labels_dict = {
|
||||||
"count": ["None", "Very Few", "Few", "Several", "Many", "Very Many"],
|
"count_regimes": ["None", "Very Few", "Few", "Several", "Many", "Very Many"],
|
||||||
"density": [
|
"density_regimes": ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"],
|
||||||
"Empty",
|
|
||||||
"Very Sparse",
|
|
||||||
"Sparse",
|
|
||||||
"Moderate",
|
|
||||||
"Dense",
|
|
||||||
"Very Dense",
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
labels = labels_dict[task]
|
labels = labels_dict[task]
|
||||||
|
|
||||||
|
|
@ -138,62 +143,67 @@ def bin_values(
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, eq=False)
|
@dataclass(frozen=True, eq=False)
|
||||||
class DatasetLabels:
|
class SplittedArrays:
|
||||||
binned: pd.Series
|
"""Small wrapper for train and test arrays."""
|
||||||
|
|
||||||
train: torch.Tensor | np.ndarray | cp.ndarray
|
train: torch.Tensor | np.ndarray | cp.ndarray
|
||||||
test: torch.Tensor | np.ndarray | cp.ndarray
|
test: torch.Tensor | np.ndarray | cp.ndarray
|
||||||
raw_values: pd.Series
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, eq=False)
|
||||||
|
class TrainingSet:
|
||||||
|
"""Container for the training dataset."""
|
||||||
|
|
||||||
|
targets: gpd.GeoDataFrame
|
||||||
|
features: pd.DataFrame
|
||||||
|
X: SplittedArrays
|
||||||
|
y: SplittedArrays
|
||||||
|
z: pd.Series
|
||||||
|
split: pd.Series
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def intervals(self) -> list[tuple[float, float] | tuple[int, int]]:
|
def target_intervals(self) -> list[tuple[float, float] | tuple[int, int] | tuple[None, None]]:
|
||||||
|
"""Calculate the intervals for each target category.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[tuple[float, float] | tuple[int, int] | tuple[None, None]]:
|
||||||
|
List of (min, max) tuples for each category.
|
||||||
|
|
||||||
|
"""
|
||||||
|
binned = self.targets["y"]
|
||||||
|
raw = self.targets["z"]
|
||||||
|
assert binned.dtype.name == "category", "Target labels are not categorical."
|
||||||
|
|
||||||
# For each category get the min and max values from raw_values
|
# For each category get the min and max values from raw_values
|
||||||
intervals = []
|
intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] = []
|
||||||
for category in self.binned.cat.categories:
|
for category in binned.cat.categories:
|
||||||
category_mask = self.binned == category
|
category_mask = binned == category
|
||||||
if category_mask.sum() == 0:
|
if category_mask.sum() == 0:
|
||||||
intervals.append((None, None))
|
intervals.append((None, None))
|
||||||
else:
|
else:
|
||||||
category_raw_values = self.raw_values[category_mask]
|
category_raw_values = raw[category_mask]
|
||||||
intervals.append((category_raw_values.min(), category_raw_values.max()))
|
intervals.append((category_raw_values.min(), category_raw_values.max()))
|
||||||
return intervals
|
return intervals
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def labels(self) -> list[str]:
|
def target_labels(self) -> list[str]:
|
||||||
return list(self.binned.cat.categories)
|
"""Labels of the target categories."""
|
||||||
|
binned = self.targets["y"]
|
||||||
|
assert binned.dtype.name == "category", "Target labels are not categorical."
|
||||||
@dataclass(frozen=True, eq=False)
|
return list(binned.cat.categories)
|
||||||
class DatasetInputs:
|
|
||||||
data: pd.DataFrame
|
|
||||||
train: torch.Tensor | np.ndarray | cp.ndarray
|
|
||||||
test: torch.Tensor | np.ndarray | cp.ndarray
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class CategoricalTrainingDataset:
|
|
||||||
dataset: gpd.GeoDataFrame
|
|
||||||
X: DatasetInputs
|
|
||||||
y: DatasetLabels
|
|
||||||
z: pd.Series
|
|
||||||
split: pd.Series
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.z)
|
return len(self.z)
|
||||||
|
|
||||||
|
|
||||||
class DatasetStats(TypedDict):
|
|
||||||
target: str
|
|
||||||
num_target_samples: int
|
|
||||||
members: dict[str, dict[str, object]]
|
|
||||||
total_features: int
|
|
||||||
|
|
||||||
|
|
||||||
@cyclopts.Parameter("*")
|
@cyclopts.Parameter("*")
|
||||||
@dataclass(frozen=True, kw_only=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class DatasetEnsemble:
|
class DatasetEnsemble:
|
||||||
|
"""An ensemble of datasets for training and inference."""
|
||||||
|
|
||||||
grid: Grid
|
grid: Grid
|
||||||
level: int
|
level: int
|
||||||
target: Literal["darts_rts", "darts_mllabels"]
|
target: TargetDataset
|
||||||
members: list[L2SourceDataset] = field(
|
members: list[L2SourceDataset] = field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
"AlphaEarth",
|
"AlphaEarth",
|
||||||
|
|
@ -203,366 +213,420 @@ class DatasetEnsemble:
|
||||||
"ERA5-shoulder",
|
"ERA5-shoulder",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict)
|
temporal_mode: TemporalMode = "synopsis"
|
||||||
variable_filters: dict[str, list[str]] = field(default_factory=dict)
|
dimension_filters: dict[L2SourceDataset, dict[str, list]] = field(default_factory=dict)
|
||||||
filter_target: str | Literal[False] = False
|
variable_filters: dict[L2SourceDataset, list[str]] = field(default_factory=dict)
|
||||||
add_lonlat: bool = True
|
add_lonlat: bool = True
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# Validate filters
|
||||||
|
for member in self.members:
|
||||||
|
if member in self.variable_filters:
|
||||||
|
# This enforces the assumption that the read members always are xarray Datasets and not DataArrays
|
||||||
|
assert isinstance(self.variable_filters[member], list) and len(self.variable_filters[member]) >= 1, (
|
||||||
|
f"Invalid variable filter for {member=}: {self.variable_filters[member]}"
|
||||||
|
" Variable filter values must be a list with one or more entries."
|
||||||
|
)
|
||||||
|
if member in self.dimension_filters:
|
||||||
|
for dim, values in self.dimension_filters[member].items():
|
||||||
|
# This enforces the assumption that we do the temporal filtering via the temporal_mode
|
||||||
|
assert dim not in ["year", "month", "time"], (
|
||||||
|
f"Invalid dimension filter for {member=}: {dim}"
|
||||||
|
" Filtering on 'year', 'month' or 'time' is not supported."
|
||||||
|
)
|
||||||
|
# This enforces the assumption that there are no empty dimensions in the Dataset
|
||||||
|
assert isinstance(values, list) and len(values) >= 1, (
|
||||||
|
f"Invalid dimension filter for {dim=}: {values}"
|
||||||
|
" Dimension filter values must be a list with one or more entries."
|
||||||
|
)
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return int(self.id(), 16)
|
return int(self.id(), 16)
|
||||||
|
|
||||||
|
@cache
|
||||||
def id(self):
|
def id(self):
|
||||||
|
"""Return an unique, stable identifier based on the settings of this class."""
|
||||||
return hashlib.blake2b(
|
return hashlib.blake2b(
|
||||||
json.dumps(asdict(self), sort_keys=True).encode("utf-8"),
|
json.dumps(asdict(self), sort_keys=True).encode("utf-8"),
|
||||||
digest_size=16,
|
digest_size=16,
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
@cached_property
def cell_ids(self) -> pd.Series:
    """Return all grid cell IDs as a Series (indexed by cell id).

    NOTE(review): read_grid() finishes with set_index("cell_id"), which drops
    the "cell_id" column, so subscripting with ["cell_id"] would raise a
    KeyError. Read the IDs from the index instead.
    """
    return self.read_grid().index.to_series()
@cached_property
def geometries(self) -> pd.Series:
    """Return the grid cell geometries as a Series indexed by cell id."""
    return self.read_grid()["geometry"]
# @stopwatch("Reading grid")
def read_grid(self) -> gpd.GeoDataFrame:
    """Read the spatial grid for this ensemble configuration.

    Returns:
        gpd.GeoDataFrame: The grid indexed by "cell_id", with a "geometry"
            column and optionally "lon"/"lat" cell-center columns.

    """
    grid_gdf = entropice.spatial.grids.open(grid=self.grid, level=self.level)

    # Add the lat / lon of the cell centers
    if self.add_lonlat:
        grid_gdf["lon"] = grid_gdf.geometry.centroid.x
        grid_gdf["lat"] = grid_gdf.geometry.centroid.y

    # Convert hex cell_id to int (hex grids store the id as a hex string)
    if self.grid == "hex":
        grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda x: int(x, 16)).astype(np.uint64)

    grid_gdf = grid_gdf.set_index("cell_id")
    return grid_gdf
# @stopwatch.f("Getting target labels", print_kwargs=["stage"])
|
||||||
|
def get_targets(self, task: Task) -> gpd.GeoDataFrame:
|
||||||
|
"""Create a training target labels for a specific task.
|
||||||
|
|
||||||
|
The function reads the target dataset, filters it based on coverage,
|
||||||
|
and prepares the target labels according to the specified task.
|
||||||
|
|
||||||
|
This is quite a fast function, with the bottleneck usually being the reading of the zarr store
|
||||||
|
(milliseconds to few seconds).
|
||||||
|
|
||||||
|
This is intended to be used by another training Dataset creation function
|
||||||
|
and for easy visualization of the targets.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (Task): The task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
gpd.GeoDataFrame: GeoDataFrame with target labels, geometry and raw-target values,
|
||||||
|
all indexed by the cell_ids.
|
||||||
|
Columns: ["geometry", "y", "z"], where "y" is the target label and "z" the raw value.
|
||||||
|
For categorical tasks, "y" is categorical and "z" is numerical.
|
||||||
|
For regression tasks, both "y" and "z" are numerical and identical.
|
||||||
|
|
||||||
|
"""
|
||||||
|
match (self.target, self.temporal_mode):
|
||||||
|
case ("darts_v1" | "darts_v2", "feature" | "synopsis"):
|
||||||
|
version: Literal["v1-l3", "v2-l3"] = self.target.split("_")[1] + "-l3" # ty:ignore[invalid-assignment]
|
||||||
|
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
|
||||||
|
targets = xr.open_zarr(target_store, consolidated=False)
|
||||||
|
case ("darts_v1" | "darts_v2", int()):
|
||||||
|
version: Literal["v1", "v2"] = self.target.split("_")[1] # ty:ignore[invalid-assignment]
|
||||||
|
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
|
||||||
|
targets = xr.open_zarr(target_store, consolidated=False).sel(year=self.temporal_mode)
|
||||||
|
case ("darts_mllabels", str()): # Years are not supported
|
||||||
|
target_store = entropice.utils.paths.get_darts_file(
|
||||||
|
grid=self.grid, level=self.level, version="mllabels"
|
||||||
|
)
|
||||||
|
targets = xr.open_zarr(target_store, consolidated=False)
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError(f"Target {self.target} on {self.temporal_mode} mode not supported.")
|
||||||
|
targets = cast(xr.Dataset, targets)
|
||||||
|
covered_cell_ids = targets["coverage"].where(targets["coverage"] > 0).dropna("cell_ids")["cell_ids"].to_series()
|
||||||
|
targets = targets.sel(cell_ids=covered_cell_ids.to_numpy())
|
||||||
|
match task:
|
||||||
|
case "binary":
|
||||||
|
z = targets["count"].to_series()
|
||||||
|
y = (z > 0).astype("category").map({False: "No RTS", True: "RTS"})
|
||||||
|
case "count_regimes":
|
||||||
|
z = targets["count"].to_series()
|
||||||
|
y = bin_values(z, task="count_regimes", none_val=0)
|
||||||
|
case "density_regimes":
|
||||||
|
z = targets["density"].to_series()
|
||||||
|
y = bin_values(z, task="density_regimes", none_val=0)
|
||||||
|
case "count":
|
||||||
|
z = targets["count"].to_series()
|
||||||
|
y = z
|
||||||
|
case "density":
|
||||||
|
z = targets["density"].to_series()
|
||||||
|
y = z
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError(f"Task {task} not supported.")
|
||||||
|
cell_ids = targets["cell_ids"].to_series()
|
||||||
|
geometries = self.geometries.loc[cell_ids]
|
||||||
|
return gpd.GeoDataFrame(
|
||||||
|
{
|
||||||
|
"cell_id": cell_ids,
|
||||||
|
"geometry": geometries,
|
||||||
|
"y": y,
|
||||||
|
"z": z,
|
||||||
|
}
|
||||||
|
).set_index("cell_id")
|
||||||
|
|
||||||
|
# @stopwatch.f("Reading member", print_kwargs=["member", "stage", "lazy"])
|
||||||
|
def read_member(self, member: L2SourceDataset, cell_ids: pd.Series | None = None, lazy: bool = False) -> xr.Dataset: # noqa: C901
|
||||||
|
"""Read a single member (source) of the Ensemble and applies filters based on the ensemble configuration.
|
||||||
|
|
||||||
|
When `lazy` is False, the data is actually read into memory, which can take some time (seconds).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
member (L2SourceDataset): The member
|
||||||
|
cell_ids (pd.Series | None, optional): The cell IDs to read. If None, all cell IDs are read.
|
||||||
|
Defaults to None.
|
||||||
|
lazy (bool, optional): Whether to not load the data and instead return a lazy dataset. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
xr.Dataset: Xarray Dataset with the member data for the given cell IDs.
|
||||||
|
|
||||||
|
"""
|
||||||
|
match member:
|
||||||
|
case "AlphaEarth":
|
||||||
store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level)
|
store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level)
|
||||||
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
|
case "ERA5-yearly" | "ERA5-seasonal" | "ERA5-shoulder":
|
||||||
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
|
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
|
||||||
store = entropice.utils.paths.get_era5_stores(era5_agg, grid=self.grid, level=self.level)
|
store = entropice.utils.paths.get_era5_stores(era5_agg, grid=self.grid, level=self.level)
|
||||||
elif member == "ArcticDEM":
|
case "ArcticDEM":
|
||||||
store = entropice.utils.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
|
store = entropice.utils.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
|
||||||
else:
|
case _:
|
||||||
raise NotImplementedError(f"Member {member} not implemented.")
|
raise NotImplementedError(f"Member {member} not implemented.")
|
||||||
|
|
||||||
ds = xr.open_zarr(store, consolidated=False)
|
ds = xr.open_zarr(store, consolidated=False)
|
||||||
|
|
||||||
# Apply variable filters
|
# Apply variable and dimension filters
|
||||||
if member in self.variable_filters:
|
if member in self.variable_filters:
|
||||||
assert isinstance(self.variable_filters[member], list) and len(self.variable_filters[member]) >= 1, (
|
|
||||||
f"Invalid variable filter for {member=}: {self.variable_filters[member]}"
|
|
||||||
" Variable filter values must be a list with one or more entries."
|
|
||||||
)
|
|
||||||
ds = ds[self.variable_filters[member]]
|
ds = ds[self.variable_filters[member]]
|
||||||
|
|
||||||
# Apply dimension filters
|
|
||||||
if member in self.dimension_filters:
|
if member in self.dimension_filters:
|
||||||
for dim, values in self.dimension_filters[member].items():
|
ds = ds.sel(self.dimension_filters[member])
|
||||||
assert isinstance(values, list) and len(values) >= 1, (
|
|
||||||
f"Invalid dimension filter for {dim=}: {values}"
|
|
||||||
" Dimension filter values must be a list with one or more entries."
|
|
||||||
)
|
|
||||||
ds = ds.sel({dim: values})
|
|
||||||
|
|
||||||
# Delete all coordinates which are not in the dimension
|
# Delete all coordinates which are not in the dimension
|
||||||
for coord in ds.coords:
|
ds = ds.drop_vars([coord for coord in ds.coords if coord not in ds.dims])
|
||||||
if coord not in ds.dims:
|
|
||||||
ds = ds.drop_vars(coord)
|
|
||||||
|
|
||||||
# Only load target cell ids
|
# Only load target cell ids
|
||||||
intersecting_cell_ids = set(ds["cell_ids"].values).intersection(set(targets["cell_id"].values))
|
if cell_ids is None:
|
||||||
|
cell_ids = self.cell_ids
|
||||||
|
intersecting_cell_ids = set(ds["cell_ids"].values).intersection(set(cell_ids.to_numpy()))
|
||||||
ds = ds.sel(cell_ids=list(intersecting_cell_ids))
|
ds = ds.sel(cell_ids=list(intersecting_cell_ids))
|
||||||
|
|
||||||
|
# Unstack era5 data if needed
|
||||||
|
if member.startswith("ERA5"):
|
||||||
|
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
|
||||||
|
ds = _unstack_era5_time(ds, era5_agg)
|
||||||
|
|
||||||
|
# Apply the temporal mode
|
||||||
|
match (member.split("-"), self.temporal_mode):
|
||||||
|
case (["ERA5", _] | ["AlphaEarth"], "synopsis"):
|
||||||
|
ds_mean = ds.mean(dim="year")
|
||||||
|
ds_trend = ds.polyfit(dim="year", deg=1).sel(degree=1, drop=True)
|
||||||
|
# Rename all cols from "{var}_polyfit_coefficients" to "{var}_trend"
|
||||||
|
ds_trend = ds_trend.rename(
|
||||||
|
{var: str(var).replace("_polyfit_coefficients", "_trend") for var in ds_trend.data_vars}
|
||||||
|
)
|
||||||
|
ds = xr.merge([ds_mean, ds_trend])
|
||||||
|
case (["ArcticDEM"], "synopsis"):
|
||||||
|
pass # No temporal dimension
|
||||||
|
case (_, int() as year):
|
||||||
|
ds = ds.sel(year=year, drop=True)
|
||||||
|
case (_, "feature"):
|
||||||
|
pass
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError(f"Temporal mode {self.temporal_mode} not implemented for member {member}.")
|
||||||
|
|
||||||
# Actually read data into memory
|
# Actually read data into memory
|
||||||
if not lazy:
|
if not lazy:
|
||||||
ds.load()
|
ds.load()
|
||||||
return ds
|
return ds
|
||||||
|
|
||||||
def _read_target(self) -> gpd.GeoDataFrame:
|
def make_features(
|
||||||
if self.target == "darts_rts":
|
self,
|
||||||
target_store = entropice.utils.paths.get_darts_rts_file(grid=self.grid, level=self.level)
|
cell_ids: pd.Series | None = None,
|
||||||
elif self.target == "darts_mllabels":
|
cache_mode: Literal["none", "read", "overwrite"] = "none",
|
||||||
target_store = entropice.utils.paths.get_darts_rts_file(grid=self.grid, level=self.level, labels=True)
|
) -> pd.DataFrame:
|
||||||
else:
|
"""Create a feature DataFrame for the given temporal task and cell IDs.
|
||||||
raise NotImplementedError(f"Target {self.target} not implemented.")
|
|
||||||
targets = gpd.read_parquet(target_store)
|
|
||||||
|
|
||||||
# Filter to coverage
|
This function reads all members of the ensemble, prepares their features based on the temporal task,
|
||||||
if self.filter_target:
|
and combines them into a single DataFrame (no geometry).
|
||||||
targets = targets[targets[self.filter_target]]
|
It is quite computation intensive (seconds to minutes), depending on the configuration and number of cell IDs.
|
||||||
# Convert hex cell_id to int
|
To speed up repeated calls, a caching mechanism is implemented.
|
||||||
if self.grid == "hex":
|
|
||||||
targets["cell_id"] = targets["cell_id"].apply(lambda x: int(x, 16))
|
|
||||||
|
|
||||||
# Add the lat / lon of the cell centers
|
The indented use for this function is solely the creation of training and inference datasets.
|
||||||
if self.add_lonlat:
|
For visualization purposes it is recommended to use the `read_member` with `lazy=True` functions directly.
|
||||||
targets["lon"] = targets.geometry.centroid.x
|
|
||||||
targets["lat"] = targets.geometry.centroid.y
|
|
||||||
|
|
||||||
return targets
|
Args:
|
||||||
|
cell_ids (pd.Series | None, optional): The cell IDs to read. If None, all cell IDs are read.
|
||||||
|
Defaults to None.
|
||||||
|
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode.
|
||||||
|
"none": No caching.
|
||||||
|
"read": Read from cache if exists, otherwise create and save to cache.
|
||||||
|
"overwrite": Always create and save to cache, overwriting existing cache.
|
||||||
|
Defaults to "none".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pd.DataFrame: DataFrame with features for the given cell IDs.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if cell_ids is None:
|
||||||
|
cell_ids = self.cell_ids
|
||||||
|
|
||||||
|
# Caching mechanism
|
||||||
|
if cache_mode != "none":
|
||||||
|
cells_hash = _cell_ids_hash(cell_ids)
|
||||||
|
cache_file = entropice.utils.paths.get_features_cache(ensemble_id=self.id(), cells_hash=cells_hash)
|
||||||
|
if cache_mode == "read" and cache_file.exists():
|
||||||
|
with stopwatch("Loading features from cache"):
|
||||||
|
return pd.read_parquet(cache_file)
|
||||||
|
|
||||||
|
# Create features
|
||||||
|
dataset = self.read_grid().loc[cell_ids].drop(columns=["geometry"])
|
||||||
|
member_dfs = []
|
||||||
|
for member in self.members:
|
||||||
|
match member.split("-"):
|
||||||
|
case ["ERA5", era5_agg]:
|
||||||
|
member_dfs.append(self._prep_era5(cell_ids, era5_agg))
|
||||||
|
case ["AlphaEarth"]:
|
||||||
|
member_dfs.append(self._prep_embeddings(cell_ids))
|
||||||
|
case ["ArcticDEM"]:
|
||||||
|
member_dfs.append(self._prep_arcticdem(cell_ids))
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError(f"Member {member} not implemented.")
|
||||||
|
with stopwatch("Joining datasets"):
|
||||||
|
dataset = dataset.join(member_dfs)
|
||||||
|
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
|
||||||
|
|
||||||
|
if cache_mode in ["read", "overwrite"]:
|
||||||
|
with stopwatch("Saving features to cache"):
|
||||||
|
dataset.to_parquet(cache_file)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
# @stopwatch.f("Preparing ERA5", print_kwargs=["stage", "temporal"])
def _prep_era5(
    self,
    cell_ids: pd.Series,
    era5_agg: Literal["yearly", "seasonal", "shoulder"],
) -> pd.DataFrame:
    """Read and flatten one ERA5 member into a feature DataFrame indexed by cell id."""
    era5 = self.read_member("ERA5-" + era5_agg, cell_ids=cell_ids, lazy=False)
    era5_df = _collapse_to_dataframe(era5)
    era5_df.columns = [f"era5_{col}" for col in era5_df.columns]
    # Ensure all target cell_ids are present, fill missing with NaN
    era5_df = era5_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
    return era5_df
# @stopwatch.f("Preparing ALphaEarth Embeddings", print_kwargs=["stage"])
def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame:
    """Read and flatten the AlphaEarth embeddings into a feature DataFrame indexed by cell id."""
    embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"]
    embeddings_df = _collapse_to_dataframe(embeddings)
    embeddings_df.columns = [f"embeddings_{col}" for col in embeddings_df.columns]
    # Ensure all target cell_ids are present, fill missing with NaN
    embeddings_df = embeddings_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
    return embeddings_df
# @stopwatch.f("Preparing ArcticDEM", print_kwargs=["stage"])
def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame:
    """Read and flatten the ArcticDEM member into a feature DataFrame indexed by cell id."""
    arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True)
    if len(arcticdem["cell_ids"]) == 0:
        # No data for these cells - create empty DataFrame with expected columns
        # Use the Dataset metadata to determine column structure
        variables = list(arcticdem.data_vars)
        aggs = arcticdem.coords["aggregations"].to_numpy()
        expected_cols = [f"arcticdem_{var}_{agg}" for var, agg in product(variables, aggs)]
        return pd.DataFrame(index=cell_ids.to_numpy(), columns=expected_cols, dtype=float)
    arcticdem_df = arcticdem.to_dataframe().pivot_table(index="cell_ids", columns="aggregations")
    arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns]
    # Ensure all target cell_ids are present, fill missing with NaN
    arcticdem_df = arcticdem_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
    return arcticdem_df
def create_inference_df(
    self,
    batch_size: int | None = None,
    cache_mode: Literal["none", "overwrite", "read"] = "read",
) -> Generator[pd.DataFrame]:
    """Create an inference feature set generator.

    This function creates features for all cell IDs in batches.
    If `batch_size` is None or greater than the number of cell IDs, all data is returned in a single batch.

    Args:
        batch_size (int | None, optional): The batch size. If None, all data is returned in a single batch.
            Defaults to None.
        cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
            "none": No caching.
            "read": Read from cache if exists, otherwise create and save to cache.
            "overwrite": Always create and save to cache, overwriting existing cache.
            Defaults to "none".

    Yields:
        Generator[pd.DataFrame]: Generator yielding DataFrames with features for inference.

    """
    all_cell_ids = self.cell_ids
    if batch_size is None or batch_size >= len(all_cell_ids):
        # Single-batch shortcut: everything at once.
        yield self.make_features(cell_ids=all_cell_ids, cache_mode=cache_mode)
        return
    for i in range(0, len(all_cell_ids), batch_size):
        batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
        yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
def create_training_df(
    self,
    task: Task,
    cache_mode: Literal["none", "overwrite", "read"] = "read",
) -> pd.DataFrame:
    """Create a simple DataFrame for training with e.g. autogluon.

    This function creates a DataFrame with features and target labels (called "y") for training.
    The resulting dataframe contains no geometry information and drops all samples with NaN values.
    Does not split into train and test sets, also does not convert to arrays.
    For more advanced use cases, use `create_training_set`.

    Args:
        task (Task): The task.
        cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
            "none": No caching.
            "read": Read from cache if exists, otherwise create and save to cache.
            "overwrite": Always create and save to cache, overwriting existing cache.
            Defaults to "none".

    Returns:
        pd.DataFrame: The training DataFrame.

    """
    targets = self.get_targets(task)
    assert len(targets) > 0, "No target samples found."
    # Features are built only for cells that actually have a target label.
    features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode)
    dataset = targets[["y"]].join(features).dropna()
    assert len(dataset) > 0, "No valid samples found after joining features and targets."
    return dataset
if task == "binary":
|
def create_training_set(
|
||||||
binned = model_labels.map({False: "No RTS", True: "RTS"}).astype("category")
|
self,
|
||||||
elif task == "count":
|
task: Task,
|
||||||
binned = bin_values(model_labels.astype(int), task=task)
|
device: Literal["cpu", "cuda", "torch"] = "cpu",
|
||||||
elif task == "density":
|
cache_mode: Literal["none", "overwrite", "read"] = "read",
|
||||||
binned = bin_values(model_labels, task=task)
|
) -> TrainingSet:
|
||||||
else:
|
"""Create a full training set for model training.
|
||||||
raise ValueError("Invalid task.")
|
|
||||||
|
This function creates a full training set with features and target labels,
|
||||||
|
splits it into train and test sets, converts to arrays and moves to the specified device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (Task): The task.
|
||||||
|
device (Literal["cpu", "cuda", "torch"], optional): The device to move the data to. Defaults to "cpu".
|
||||||
|
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
|
||||||
|
"none": No caching.
|
||||||
|
"read": Read from cache if exists, otherwise create and save to cache.
|
||||||
|
"overwrite": Always create and save to cache, overwriting existing cache.
|
||||||
|
Defaults to "none".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TrainingSet: The training set.
|
||||||
|
Contains a lot of useful information for training and evaluation.
|
||||||
|
|
||||||
|
"""
|
||||||
|
targets = self.get_targets(task)
|
||||||
|
assert len(targets) > 0, "No target samples found."
|
||||||
|
features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode).dropna()
|
||||||
|
assert len(features) > 0, "No valid features found after dropping NaNs."
|
||||||
|
assert not features.isna().all("index").any(), "Some feature columns are all NaN"
|
||||||
|
|
||||||
# Create train / test split
|
# Create train / test split
|
||||||
train_idx, test_idx = train_test_split(dataset.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True)
|
train_idx, test_idx = train_test_split(features.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True)
|
||||||
split = pd.Series(index=dataset.index, dtype=object)
|
split = pd.Series(index=features.index, dtype=object)
|
||||||
split.loc[train_idx] = "train"
|
split.loc[train_idx] = "train"
|
||||||
split.loc[test_idx] = "test"
|
split.loc[test_idx] = "test"
|
||||||
split = split.astype("category")
|
split = split.astype("category")
|
||||||
|
|
||||||
X_train = model_inputs.loc[train_idx].to_numpy(dtype="float64")
|
X_train = features.loc[train_idx].to_numpy(dtype="float64")
|
||||||
X_test = model_inputs.loc[test_idx].to_numpy(dtype="float64")
|
X_test = features.loc[test_idx].to_numpy(dtype="float64")
|
||||||
y_train = binned.loc[train_idx].cat.codes.to_numpy(dtype="int64")
|
if task in ["binary", "count_regimes", "density_regimes"]:
|
||||||
y_test = binned.loc[test_idx].cat.codes.to_numpy(dtype="int64")
|
y_train = targets.loc[train_idx, "y"].cat.codes.to_numpy(dtype="int64")
|
||||||
|
y_test = targets.loc[test_idx, "y"].cat.codes.to_numpy(dtype="int64")
|
||||||
|
else:
|
||||||
|
y_train = targets.loc[train_idx, "y"].to_numpy(dtype="float64")
|
||||||
|
y_test = targets.loc[test_idx, "y"].to_numpy(dtype="float64")
|
||||||
|
|
||||||
if device == "cuda":
|
if device == "cuda":
|
||||||
X_train = cp.asarray(X_train)
|
X_train = cp.asarray(X_train)
|
||||||
|
|
@ -578,28 +642,13 @@ class DatasetEnsemble:
|
||||||
y_test = torch.from_numpy(y_test).to(device=torch_device)
|
y_test = torch.from_numpy(y_test).to(device=torch_device)
|
||||||
print(f"Using torch device: {torch.cuda.get_device_name(X_train.device) if X_train.is_cuda else 'cpu'}")
|
print(f"Using torch device: {torch.cuda.get_device_name(X_train.device) if X_train.is_cuda else 'cpu'}")
|
||||||
else:
|
else:
|
||||||
assert device == "cpu", "Invalid device specified."
|
assert device == "cpu", f"Invalid {device=} specified."
|
||||||
|
|
||||||
return CategoricalTrainingDataset(
|
return TrainingSet(
|
||||||
dataset=dataset.to_crs("EPSG:4326"),
|
targets=targets.loc[features.index].to_crs("EPSG:4326"),
|
||||||
X=DatasetInputs(data=model_inputs, train=X_train, test=X_test),
|
features=features,
|
||||||
y=DatasetLabels(binned=binned, train=y_train, test=y_test, raw_values=model_labels),
|
X=SplittedArrays(train=X_train, test=X_test),
|
||||||
z=model_labels,
|
y=SplittedArrays(train=y_train, test=y_test),
|
||||||
|
z=targets.loc[features.index, "z"],
|
||||||
split=split,
|
split=split,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_cat_training_dataset(
|
|
||||||
self, task: Task, device: Literal["cpu", "cuda", "torch"]
|
|
||||||
) -> CategoricalTrainingDataset:
|
|
||||||
"""Create a categorical dataset for training.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task (Task): Task type.
|
|
||||||
device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CategoricalTrainingDataset: The prepared categorical training dataset.
|
|
||||||
|
|
||||||
"""
|
|
||||||
dataset = self.create(filter_target_col=self.covcol)
|
|
||||||
return self._cat_and_split(dataset, task, device)
|
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ traceback.install()
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
||||||
|
|
||||||
def open(grid: Grid, level: int):
|
def open(grid: Grid, level: int) -> gpd.GeoDataFrame:
|
||||||
"""Open a saved grid from parquet file.
|
"""Open a saved grid from parquet file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -162,7 +162,7 @@ def create_global_hex_grid(resolution):
|
||||||
grid = gpd.GeoDataFrame(
|
grid = gpd.GeoDataFrame(
|
||||||
{"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list},
|
{"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list},
|
||||||
crs="EPSG:4326",
|
crs="EPSG:4326",
|
||||||
)
|
) # ty:ignore[no-matching-overload]
|
||||||
|
|
||||||
return grid
|
return grid
|
||||||
|
|
||||||
|
|
@ -193,7 +193,7 @@ def create_global_healpix_grid(level: int):
|
||||||
geometry = healpix_ds.dggs.cell_boundaries()
|
geometry = healpix_ds.dggs.cell_boundaries()
|
||||||
|
|
||||||
# Create GeoDataFrame
|
# Create GeoDataFrame
|
||||||
grid = gpd.GeoDataFrame({"cell_id": cell_ids, "geometry": geometry}, crs="EPSG:4326")
|
grid = gpd.GeoDataFrame({"cell_id": cell_ids, "geometry": geometry}, crs="EPSG:4326") # ty:ignore[no-matching-overload]
|
||||||
|
|
||||||
grid["cell_area"] = grid.to_crs("EPSG:3413").geometry.area / 1e6 # Convert to km^2
|
grid["cell_area"] = grid.to_crs("EPSG:3413").geometry.area / 1e6 # Convert to km^2
|
||||||
|
|
||||||
|
|
@ -250,18 +250,18 @@ def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
fig, ax = plt.subplots(1, 1, figsize=(10, 10), subplot_kw={"projection": ccrs.NorthPolarStereo()})
|
fig, ax = plt.subplots(1, 1, figsize=(10, 10), subplot_kw={"projection": ccrs.NorthPolarStereo()})
|
||||||
ax.set_extent([-180, 180, 50, 90], crs=ccrs.PlateCarree())
|
ax.set_extent([-180, 180, 50, 90], crs=ccrs.PlateCarree()) # ty:ignore[unresolved-attribute]
|
||||||
|
|
||||||
# Add features
|
# Add features
|
||||||
ax.add_feature(cfeature.LAND, zorder=0, edgecolor="black", facecolor="white")
|
ax.add_feature(cfeature.LAND, zorder=0, edgecolor="black", facecolor="white") # ty:ignore[unresolved-attribute]
|
||||||
ax.add_feature(cfeature.OCEAN, zorder=0, facecolor="lightgrey")
|
ax.add_feature(cfeature.OCEAN, zorder=0, facecolor="lightgrey") # ty:ignore[unresolved-attribute]
|
||||||
ax.add_feature(cfeature.COASTLINE)
|
ax.add_feature(cfeature.COASTLINE) # ty:ignore[unresolved-attribute]
|
||||||
ax.add_feature(cfeature.BORDERS, linestyle=":")
|
ax.add_feature(cfeature.BORDERS, linestyle=":") # ty:ignore[unresolved-attribute]
|
||||||
ax.add_feature(cfeature.LAKES, alpha=0.5)
|
ax.add_feature(cfeature.LAKES, alpha=0.5) # ty:ignore[unresolved-attribute]
|
||||||
ax.add_feature(cfeature.RIVERS)
|
ax.add_feature(cfeature.RIVERS) # ty:ignore[unresolved-attribute]
|
||||||
|
|
||||||
# Add gridlines
|
# Add gridlines
|
||||||
gl = ax.gridlines(draw_labels=True)
|
gl = ax.gridlines(draw_labels=True) # ty:ignore[unresolved-attribute]
|
||||||
gl.top_labels = False
|
gl.top_labels = False
|
||||||
gl.right_labels = False
|
gl.right_labels = False
|
||||||
|
|
||||||
|
|
@ -292,7 +292,7 @@ def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure:
|
||||||
verts = np.vstack([np.sin(theta), np.cos(theta)]).T
|
verts = np.vstack([np.sin(theta), np.cos(theta)]).T
|
||||||
circle = mpath.Path(verts * radius + center)
|
circle = mpath.Path(verts * radius + center)
|
||||||
|
|
||||||
ax.set_boundary(circle, transform=ax.transAxes)
|
ax.set_boundary(circle, transform=ax.transAxes) # ty:ignore[unresolved-attribute]
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,9 @@ DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcodin
|
||||||
|
|
||||||
GRIDS_DIR = DATA_DIR / "grids"
|
GRIDS_DIR = DATA_DIR / "grids"
|
||||||
FIGURES_DIR = Path("figures")
|
FIGURES_DIR = Path("figures")
|
||||||
RTS_DIR = DATA_DIR / "darts-rts"
|
DARTS_V1_DIR = DATA_DIR / "darts-v1"
|
||||||
RTS_LABELS_DIR = DATA_DIR / "darts-rts-mllabels"
|
DARTS_V2_DIR = DATA_DIR / "darts-v2"
|
||||||
|
DARTS_MLLABELS_DIR = DATA_DIR / "darts-mllabels"
|
||||||
ERA5_DIR = DATA_DIR / "era5"
|
ERA5_DIR = DATA_DIR / "era5"
|
||||||
ARCTICDEM_DIR = DATA_DIR / "arcticdem"
|
ARCTICDEM_DIR = DATA_DIR / "arcticdem"
|
||||||
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
|
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
|
||||||
|
|
@ -27,7 +28,9 @@ RESULTS_DIR = DATA_DIR / "results"
|
||||||
|
|
||||||
GRIDS_DIR.mkdir(parents=True, exist_ok=True)
|
GRIDS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
|
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
RTS_DIR.mkdir(parents=True, exist_ok=True)
|
DARTS_V1_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
DARTS_V2_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
DARTS_MLLABELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
ERA5_DIR.mkdir(parents=True, exist_ok=True)
|
ERA5_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
ARCTICDEM_DIR.mkdir(parents=True, exist_ok=True)
|
ARCTICDEM_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
|
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -39,10 +42,6 @@ DATASET_ENSEMBLES_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
|
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
|
||||||
|
|
||||||
dartsl2_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
|
|
||||||
dartsl2_cov_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
|
|
||||||
darts_ml_training_labels_repo = RTS_LABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
|
|
||||||
|
|
||||||
|
|
||||||
def _get_gridname(grid: Grid, level: int) -> str:
|
def _get_gridname(grid: Grid, level: int) -> str:
|
||||||
return f"permafrost_{grid}{level}"
|
return f"permafrost_{grid}{level}"
|
||||||
|
|
@ -60,13 +59,18 @@ def get_grid_viz_file(grid: Grid, level: int) -> Path:
|
||||||
return vizfile
|
return vizfile
|
||||||
|
|
||||||
|
|
||||||
def get_darts_rts_file(grid: Grid, level: int, labels: bool = False) -> Path:
|
def get_darts_file(grid: Grid, level: int, version: Literal["v1", "v1-l3", "v2", "v2-l3", "mllabels"]) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
if labels:
|
match version:
|
||||||
rtsfile = RTS_LABELS_DIR / f"{gridname}_darts-mllabels.parquet"
|
case "v1" | "v1-l3":
|
||||||
else:
|
darts_file = DARTS_V1_DIR / f"{gridname}_darts-{version}.zarr"
|
||||||
rtsfile = RTS_DIR / f"{gridname}_darts.parquet"
|
case "v2" | "v2-l3":
|
||||||
return rtsfile
|
darts_file = DARTS_V2_DIR / f"{gridname}_darts-{version}.zarr"
|
||||||
|
case "mllabels":
|
||||||
|
darts_file = DARTS_MLLABELS_DIR / f"{gridname}_darts-{version}.zarr"
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown DARTS version: {version}")
|
||||||
|
return darts_file
|
||||||
|
|
||||||
|
|
||||||
def get_annual_embeddings_file(grid: Grid, level: int, year: int) -> Path:
|
def get_annual_embeddings_file(grid: Grid, level: int, year: int) -> Path:
|
||||||
|
|
@ -85,13 +89,20 @@ def get_era5_stores(
|
||||||
agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily",
|
agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily",
|
||||||
grid: Grid | None = None,
|
grid: Grid | None = None,
|
||||||
level: int | None = None,
|
level: int | None = None,
|
||||||
|
temporal: Literal["synopsis"] | None = None,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
if grid is None or level is None:
|
pdir = ERA5_DIR
|
||||||
(ERA5_DIR / "intermediate").mkdir(parents=True, exist_ok=True)
|
if temporal is not None:
|
||||||
return ERA5_DIR / "intermediate" / f"{agg}_climate.zarr"
|
agg += f"_{temporal}" # ty:ignore[invalid-assignment]
|
||||||
|
fname = f"{agg}_climate.zarr"
|
||||||
|
|
||||||
|
if grid is None or level is None:
|
||||||
|
pdir = pdir / "intermediate"
|
||||||
|
pdir.mkdir(parents=True, exist_ok=True)
|
||||||
|
else:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
aligned_path = ERA5_DIR / f"{gridname}_{agg}_climate.zarr"
|
fname = f"{gridname}_{fname}"
|
||||||
|
aligned_path = pdir / fname
|
||||||
return aligned_path
|
return aligned_path
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -122,6 +133,12 @@ def get_dataset_cache(eid: str, subset: str | None = None, batch: tuple[int, int
|
||||||
return cache_file
|
return cache_file
|
||||||
|
|
||||||
|
|
||||||
|
def get_features_cache(ensemble_id: str, cells_hash: str) -> Path:
|
||||||
|
cache_dir = DATASET_ENSEMBLES_DIR / "cache" / "features" / ensemble_id
|
||||||
|
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
return cache_dir / f"cells_{cells_hash}.parquet"
|
||||||
|
|
||||||
|
|
||||||
def get_cv_results_dir(
|
def get_cv_results_dir(
|
||||||
name: str,
|
name: str,
|
||||||
grid: Grid,
|
grid: Grid,
|
||||||
|
|
@ -133,3 +150,15 @@ def get_cv_results_dir(
|
||||||
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}"
|
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}"
|
||||||
results_dir.mkdir(parents=True, exist_ok=True)
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
return results_dir
|
return results_dir
|
||||||
|
|
||||||
|
|
||||||
|
def get_autogluon_results_dir(
|
||||||
|
grid: Grid,
|
||||||
|
level: int,
|
||||||
|
task: Task,
|
||||||
|
) -> Path:
|
||||||
|
gridname = _get_gridname(grid, level)
|
||||||
|
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
|
results_dir = RESULTS_DIR / f"{gridname}_autogluon_{now}_{task}"
|
||||||
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
return results_dir
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,14 @@ type GridLevel = Literal[
|
||||||
"healpix9",
|
"healpix9",
|
||||||
"healpix10",
|
"healpix10",
|
||||||
]
|
]
|
||||||
type TargetDataset = Literal["darts_rts", "darts_mllabels"]
|
type TargetDataset = Literal["darts_v2", "darts_v1", "darts_mllabels", "darts_rts"]
|
||||||
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
|
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
|
||||||
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
|
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
|
||||||
type Task = Literal["binary", "count", "density"]
|
type Task = Literal["binary", "count_regimes", "density_regimes", "count", "density"]
|
||||||
|
# TODO: Consider implementing a "timeseries" temporal mode
|
||||||
|
type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024]
|
||||||
type Model = Literal["espa", "xgboost", "rf", "knn"]
|
type Model = Literal["espa", "xgboost", "rf", "knn"]
|
||||||
|
type Stage = Literal["train", "inference", "visualization"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
@ -70,8 +73,9 @@ class GridConfig:
|
||||||
|
|
||||||
|
|
||||||
# Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists
|
# Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists
|
||||||
all_tasks: list[Task] = ["binary", "count", "density"]
|
all_tasks: list[Task] = ["binary", "count_regimes", "density_regimes", "count", "density"]
|
||||||
all_target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
|
all_temporal_modes: list[TemporalMode] = ["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024]
|
||||||
|
all_target_datasets: list[TargetDataset] = ["darts_mllabels", "darts_rts"]
|
||||||
all_l2_source_datasets: list[L2SourceDataset] = [
|
all_l2_source_datasets: list[L2SourceDataset] = [
|
||||||
"ArcticDEM",
|
"ArcticDEM",
|
||||||
"ERA5-shoulder",
|
"ERA5-shoulder",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue