Add first training
This commit is contained in:
parent
ad3d7aae73
commit
3e0e6e0d2d
11 changed files with 5368 additions and 83 deletions
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
# SCM syntax highlighting & preventing 3-way merges
|
||||||
|
pixi.lock merge=binary linguist-language=YAML linguist-generated=true
|
||||||
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -21,3 +21,7 @@ pg.ipynb
|
||||||
playground.ipynb
|
playground.ipynb
|
||||||
*fix*.ipynb
|
*fix*.ipynb
|
||||||
*debug*.ipynb
|
*debug*.ipynb
|
||||||
|
|
||||||
|
# pixi environments
|
||||||
|
.pixi
|
||||||
|
*.egg-info
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ version = "0.1.0"
|
||||||
description = "Add your description here"
|
description = "Add your description here"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [{ name = "Tobias Hölzer", email = "tobiashoelzer@hotmail.com" }]
|
authors = [{ name = "Tobias Hölzer", email = "tobiashoelzer@hotmail.com" }]
|
||||||
requires-python = ">=3.12"
|
# requires-python = ">=3.10,<3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aiohttp>=3.12.11",
|
"aiohttp>=3.12.11",
|
||||||
"bokeh>=3.7.3",
|
"bokeh>=3.7.3",
|
||||||
|
|
@ -16,7 +16,7 @@ dependencies = [
|
||||||
"distributed>=2025.5.1",
|
"distributed>=2025.5.1",
|
||||||
"earthengine-api>=1.6.9",
|
"earthengine-api>=1.6.9",
|
||||||
"eemont>=2025.7.1",
|
"eemont>=2025.7.1",
|
||||||
"entropyc",
|
"entropy",
|
||||||
"flox>=0.10.4",
|
"flox>=0.10.4",
|
||||||
"folium>=0.19.7",
|
"folium>=0.19.7",
|
||||||
"geemap>=0.36.3",
|
"geemap>=0.36.3",
|
||||||
|
|
@ -48,7 +48,7 @@ dependencies = [
|
||||||
"xarray>=2025.9.0",
|
"xarray>=2025.9.0",
|
||||||
"xdggs>=0.2.1",
|
"xdggs>=0.2.1",
|
||||||
"xvec>=0.5.1",
|
"xvec>=0.5.1",
|
||||||
"zarr[remote]>=3.1.3",
|
"zarr[remote]>=3.1.3", "geocube>=0.7.1,<0.8",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
@ -56,6 +56,7 @@ create-grid = "entropice.grids:main"
|
||||||
darts = "entropice.darts:main"
|
darts = "entropice.darts:main"
|
||||||
alpha-earth = "entropice.alphaearth:main"
|
alpha-earth = "entropice.alphaearth:main"
|
||||||
era5 = "entropice.era5:cli"
|
era5 = "entropice.era5:cli"
|
||||||
|
train = "entropice.training:cli"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
|
|
@ -66,4 +67,40 @@ package = true
|
||||||
|
|
||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
|
entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
|
||||||
|
entropy = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" }
|
||||||
xanimate = { git = "https://github.com/davbyr/xAnimate" }
|
xanimate = { git = "https://github.com/davbyr/xAnimate" }
|
||||||
|
|
||||||
|
[tool.ruff.lint.pyflakes]
|
||||||
|
# Ignore libraries when checking for unused imports
|
||||||
|
allowed-unused-imports = [
|
||||||
|
"hvplot.pandas",
|
||||||
|
"hvplot.xarray",
|
||||||
|
"rioxarray",
|
||||||
|
"odc.geo.xr",
|
||||||
|
"cupy_xarray",
|
||||||
|
"xdggs",
|
||||||
|
"xvec",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.pixi.workspace]
|
||||||
|
channels = ["conda-forge"]
|
||||||
|
platforms = ["linux-64"]
|
||||||
|
|
||||||
|
[tool.pixi.activation.env]
|
||||||
|
SCIPY_ARRAY_API = "1"
|
||||||
|
|
||||||
|
[tool.pixi.system-requirements]
|
||||||
|
cuda = "12"
|
||||||
|
|
||||||
|
[tool.pixi.pypi-dependencies]
|
||||||
|
entropice = { path = ".", editable = true }
|
||||||
|
|
||||||
|
[tool.pixi.tasks]
|
||||||
|
|
||||||
|
[tool.pixi.dependencies]
|
||||||
|
pytorch-gpu = ">=2.5.1,<3"
|
||||||
|
cupy = ">=13.6.0,<14"
|
||||||
|
nccl = ">=2.27.7.1,<3"
|
||||||
|
cudnn = ">=9.13.1.26,<10"
|
||||||
|
cusparselt = ">=0.8.1.1,<0.9"
|
||||||
|
cuda-version = "12.1.*"
|
||||||
|
|
|
||||||
|
|
@ -58,6 +58,15 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
|
||||||
grid_gdf[f"darts_{year}_rts_area"] / grid_gdf[f"darts_{year}_covered_area"]
|
grid_gdf[f"darts_{year}_rts_area"] / grid_gdf[f"darts_{year}_covered_area"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply corrections to NaNs
|
||||||
|
covered = ~grid_gdf[f"darts_{year}_coverage"].isnull()
|
||||||
|
grid_gdf.loc[covered, f"darts_{year}_rts_count"] = grid_gdf.loc[covered, f"darts_{year}_rts_count"].fillna(
|
||||||
|
0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1)
|
||||||
|
grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1)
|
||||||
|
|
||||||
output_path = get_darts_rts_file(grid, level)
|
output_path = get_darts_rts_file(grid, level)
|
||||||
grid_gdf.to_parquet(output_path)
|
grid_gdf.to_parquet(output_path)
|
||||||
print(f"Saved RTS labels to {output_path}")
|
print(f"Saved RTS labels to {output_path}")
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,10 @@ Monthly, Winter, Summer & Yearly Aggregations (Names don't change):
|
||||||
- *_rest -> median
|
- *_rest -> median
|
||||||
- accum variables: sum
|
- accum variables: sum
|
||||||
|
|
||||||
|
Additionally:
|
||||||
|
- snow_cover_min [instant]: min of snowc_mean over month/winter/summer/year
|
||||||
|
- snow_cover_max [instant]: max of snowc_mean over month/winter/summer/year
|
||||||
|
|
||||||
Derived & (from monthly) Aggregated Winter Variables:
|
Derived & (from monthly) Aggregated Winter Variables:
|
||||||
- effective_snow_depth [instant]: (sde_mean * M + 1 - m).sum(M) / (m).sum(M),see also https://tc.copernicus.org/articles/11/989/2017/tc-11-989-2017.pdf
|
- effective_snow_depth [instant]: (sde_mean * M + 1 - m).sum(M) / (m).sum(M),see also https://tc.copernicus.org/articles/11/989/2017/tc-11-989-2017.pdf
|
||||||
|
|
||||||
|
|
@ -75,6 +79,7 @@ Date: June to October 2025
|
||||||
|
|
||||||
import cProfile
|
import cProfile
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import cyclopts
|
import cyclopts
|
||||||
|
|
@ -88,6 +93,7 @@ import shapely.ops
|
||||||
import ultraplot as uplt
|
import ultraplot as uplt
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
import xdggs
|
import xdggs
|
||||||
|
import xvec
|
||||||
from rich import pretty, print, traceback
|
from rich import pretty, print, traceback
|
||||||
from rich.progress import track
|
from rich.progress import track
|
||||||
from shapely.geometry import LineString, Polygon
|
from shapely.geometry import LineString, Polygon
|
||||||
|
|
@ -95,6 +101,7 @@ from stopuhr import stopwatch
|
||||||
|
|
||||||
from entropice import codecs, grids, watermask
|
from entropice import codecs, grids, watermask
|
||||||
from entropice.paths import FIGURES_DIR, get_era5_stores
|
from entropice.paths import FIGURES_DIR, get_era5_stores
|
||||||
|
from entropice.xvec import to_xvec
|
||||||
|
|
||||||
traceback.install(show_locals=True, suppress=[cyclopts, xr, pd, cProfile])
|
traceback.install(show_locals=True, suppress=[cyclopts, xr, pd, cProfile])
|
||||||
pretty.install()
|
pretty.install()
|
||||||
|
|
@ -572,21 +579,38 @@ def enrich(n_workers: int = 10, monthly: bool = True, yearly: bool = True, daily
|
||||||
|
|
||||||
|
|
||||||
@cli.command
|
@cli.command
|
||||||
def viz(agg: Literal["daily", "monthly", "yearly", "summer", "winter"]):
|
def viz(
|
||||||
|
grid: Literal["hex", "healpix"] | None = None,
|
||||||
|
level: int | None = None,
|
||||||
|
agg: Literal["daily", "monthly", "yearly", "summer", "winter", "seasonal", "shoulder"] = "monthly",
|
||||||
|
high_qual: bool = False,
|
||||||
|
):
|
||||||
"""Visualize a small overview of ERA5 variables for a given aggregation.
|
"""Visualize a small overview of ERA5 variables for a given aggregation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agg (Literal["daily", "monthly", "yearly", "summer", "winter"]):
|
grid (Literal["hex", "healpix"], optional): Grid type for spatial representation.
|
||||||
|
If provided along with level, the ERA5 data will be decoded onto the specified grid.
|
||||||
|
level (int, optional): Level of the grid for spatial representation.
|
||||||
|
If provided along with grid, the ERA5 data will be decoded onto the specified grid.
|
||||||
|
agg (Literal["daily", "monthly", "yearly", "summer", "winter", "seasonal", "shoulder"], optional):
|
||||||
Aggregation identifier used to locate the appropriate ERA5 Zarr store via
|
Aggregation identifier used to locate the appropriate ERA5 Zarr store via
|
||||||
get_era5_stores. Determines which dataset is opened and visualized.
|
get_era5_stores. Determines which dataset is opened and visualized.
|
||||||
|
high_qual (bool, optional): Whether to use high quality plotting settings.
|
||||||
|
If True, the plot will be generated with higher resolution and quality settings.
|
||||||
|
Defaults to False.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> # produce and save an overview for monthly ERA5 data
|
>>> # produce and save an overview for monthly ERA5 data
|
||||||
>>> viz("monthly")
|
>>> viz("monthly")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
store = get_era5_stores(agg)
|
is_grid = grid is not None and level is not None
|
||||||
|
store = get_era5_stores(agg, grid, level)
|
||||||
ds = xr.open_zarr(store, consolidated=False).set_coords("spatial_ref")
|
ds = xr.open_zarr(store, consolidated=False).set_coords("spatial_ref")
|
||||||
|
if is_grid:
|
||||||
|
ds = xdggs.decode(ds)
|
||||||
|
# Get cell boundaries for plotting
|
||||||
|
ds = to_xvec(ds)
|
||||||
|
|
||||||
tis = [0, 1, -2, -1]
|
tis = [0, 1, -2, -1]
|
||||||
ts = [str(ds.time.isel(time=t).values)[:10] for t in tis]
|
ts = [str(ds.time.isel(time=t).values)[:10] for t in tis]
|
||||||
|
|
@ -594,16 +618,28 @@ def viz(agg: Literal["daily", "monthly", "yearly", "summer", "winter"]):
|
||||||
vunits = [ds[var].attrs.get("units", "") for var in ds.data_vars]
|
vunits = [ds[var].attrs.get("units", "") for var in ds.data_vars]
|
||||||
vlabels = [f"{name} [{unit}]" if unit else name for name, unit in zip(vnames, vunits)]
|
vlabels = [f"{name} [{unit}]" if unit else name for name, unit in zip(vnames, vunits)]
|
||||||
|
|
||||||
fig, axs = uplt.subplots(ncols=len(tis), nrows=len(vlabels), proj="npaeqd", share=0)
|
proj = uplt.Proj("npaeqd")
|
||||||
|
fig, axs = uplt.subplots(ncols=len(tis), nrows=len(vlabels), proj=proj, share=0)
|
||||||
axs.format(boundinglat=50, coast=True, toplabels=ts, leftlabels=vlabels)
|
axs.format(boundinglat=50, coast=True, toplabels=ts, leftlabels=vlabels)
|
||||||
for t in tis:
|
for t in tis:
|
||||||
for i, var in enumerate(ds.data_vars):
|
for i, var in enumerate(ds.data_vars):
|
||||||
subset = ds[var].isel(time=t).load()
|
subset = ds[var].isel(time=t)
|
||||||
m = axs[i, t].pcolormesh(subset, cmap="viridis")
|
if is_grid:
|
||||||
axs[i, t].colorbar(m, loc="ll", label=var)
|
subset.xvec.to_geodataframe().to_crs(proj.proj4_init).plot(ax=axs[i, t], column=var)
|
||||||
|
else:
|
||||||
|
if not high_qual:
|
||||||
|
subset = subset.isel(latitude=slice(None, None, 4), longitude=slice(None, None, 4))
|
||||||
|
m = axs[i, t].pcolormesh(subset, cmap="viridis")
|
||||||
|
axs[i, t].colorbar(m, loc="ll", label=var)
|
||||||
fig.format(suptitle="ERA5 Data")
|
fig.format(suptitle="ERA5 Data")
|
||||||
(FIGURES_DIR / "era5").mkdir(parents=True, exist_ok=True)
|
figdir = (FIGURES_DIR / "era5").resolve()
|
||||||
fig.savefig(FIGURES_DIR / "era5" / f"{agg}_overview_unaligned.png")
|
figdir.mkdir(parents=True, exist_ok=True)
|
||||||
|
if high_qual:
|
||||||
|
print(f"Saving ERA5 {agg} overview figure to {figdir / f'{agg}_overview_unaligned.pdf'}.")
|
||||||
|
fig.savefig(figdir / f"{agg}_overview_unaligned.pdf")
|
||||||
|
else:
|
||||||
|
print(f"Saving ERA5 {agg} overview figure to {figdir / f'{agg}_overview_unaligned.png'}.")
|
||||||
|
fig.savefig(figdir / f"{agg}_overview_unaligned.png", dpi=100)
|
||||||
|
|
||||||
|
|
||||||
# ===========================
|
# ===========================
|
||||||
|
|
@ -640,7 +676,7 @@ def _check_geom(geobox: odc.geo.geobox.GeoBox, geom: odc.geo.Geometry) -> bool:
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Getting corrected geometries", log=False)
|
@stopwatch("Getting corrected geometries", log=False)
|
||||||
def _get_corrected_geoms(geom: Polygon, gbox: odc.geo.geobox.GeoBox) -> list[odc.geo.Geometry]:
|
def _get_corrected_geoms(inp: tuple[Polygon, odc.geo.geobox.GeoBox]) -> list[odc.geo.Geometry]:
|
||||||
"""Get corrected geometries for antimeridian-crossing polygons.
|
"""Get corrected geometries for antimeridian-crossing polygons.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -651,6 +687,7 @@ def _get_corrected_geoms(geom: Polygon, gbox: odc.geo.geobox.GeoBox) -> list[odc
|
||||||
list[odc.geo.Geometry]: List of corrected, georeferenced geometries.
|
list[odc.geo.Geometry]: List of corrected, georeferenced geometries.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
geom, gbox = inp
|
||||||
# cell.geometry is a shapely Polygon
|
# cell.geometry is a shapely Polygon
|
||||||
if not _crosses_antimeridian(geom):
|
if not _crosses_antimeridian(geom):
|
||||||
geoms = [geom]
|
geoms = [geom]
|
||||||
|
|
@ -659,43 +696,71 @@ def _get_corrected_geoms(geom: Polygon, gbox: odc.geo.geobox.GeoBox) -> list[odc
|
||||||
geoms = _split_antimeridian_cell(geom)
|
geoms = _split_antimeridian_cell(geom)
|
||||||
|
|
||||||
geoms = [odc.geo.Geometry(g, crs="epsg:4326") for g in geoms]
|
geoms = [odc.geo.Geometry(g, crs="epsg:4326") for g in geoms]
|
||||||
geoms = filter(lambda g: _check_geom(gbox, g), geoms)
|
geoms = list(filter(lambda g: _check_geom(gbox, g), geoms))
|
||||||
return geoms
|
return geoms
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Extracting cell data", log=False)
|
|
||||||
def extract_cell_data(ds: xr.Dataset, geoms: list[odc.geo.Geometry]) -> xr.Dataset | bool:
|
|
||||||
"""Extract ERA5 data for a specific grid cell geometry.
|
|
||||||
|
|
||||||
Extracts and spatially averages ERA5 data within the bounds of a grid cell.
|
|
||||||
Handles antimeridian-crossing cells by splitting them appropriately.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ds (xr.Dataset): An ERA5 dataset.
|
|
||||||
geoms (list[odc.geo.Geometry]): List of (valid) geometries of the grid cell.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if len(geoms) == 0:
|
|
||||||
return False
|
|
||||||
elif len(geoms) == 1:
|
|
||||||
return ds.odc.crop(geoms[0]).drop_vars("spatial_ref").mean(["latitude", "longitude"], skipna=True).compute()
|
|
||||||
else:
|
|
||||||
parts = [
|
|
||||||
ds.odc.crop(geom).drop_vars("spatial_ref").mean(["latitude", "longitude"], skipna=True) for geom in geoms
|
|
||||||
]
|
|
||||||
parts = [part for part in parts if part.latitude.size > 0 and part.longitude.size > 0]
|
|
||||||
if len(parts) == 0:
|
|
||||||
raise ValueError("No valid parts found for geometry. This should not happen!")
|
|
||||||
elif len(parts) == 1:
|
|
||||||
return parts[0].compute()
|
|
||||||
else:
|
|
||||||
return xr.concat(parts, dim="part").mean("part", skipna=True).compute()
|
|
||||||
|
|
||||||
|
|
||||||
def _correct_longs(ds: xr.Dataset) -> xr.Dataset:
|
def _correct_longs(ds: xr.Dataset) -> xr.Dataset:
|
||||||
return ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180)).sortby("longitude")
|
return ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180)).sortby("longitude")
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch("Extracting cell data", log=False)
|
||||||
|
def _extract_cell_data(ds, geom):
|
||||||
|
cropped = ds.odc.crop(geom).drop_vars("spatial_ref")
|
||||||
|
with np.errstate(divide="ignore", invalid="ignore"):
|
||||||
|
cell_data = cropped.mean(["latitude", "longitude"])
|
||||||
|
return {var: cell_data[var].values for var in cell_data.data_vars}
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch("Extracting split cell data", log=False)
|
||||||
|
def _extract_split_cell_data(dss: xr.Dataset, geoms):
|
||||||
|
parts: list[xr.Dataset] = [ds.odc.crop(geom).drop_vars("spatial_ref") for ds, geom in zip(dss, geoms)]
|
||||||
|
partial_counts = [part.notnull().sum(dim=["latitude", "longitude"]) for part in parts]
|
||||||
|
with np.errstate(divide="ignore", invalid="ignore"):
|
||||||
|
partial_means = [part.sum(["latitude", "longitude"]) for part in parts]
|
||||||
|
n = xr.concat(partial_counts, dim="part").sum("part")
|
||||||
|
cell_data = xr.concat(partial_means, dim="part").sum("part") / n
|
||||||
|
return {var: cell_data[var].values for var in cell_data.data_vars}
|
||||||
|
|
||||||
|
|
||||||
|
@stopwatch("Aligning data")
|
||||||
|
def _align_data(cell_geometries: list[list[odc.geo.Geometry]], unaligned: xr.Dataset) -> dict[str, np.ndarray]:
|
||||||
|
# Persist the dataset, as all the operations here MUST NOT be lazy
|
||||||
|
unaligned = unaligned.load()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
var: np.full((len(cell_geometries), len(unaligned.time)), np.nan, dtype=np.float32)
|
||||||
|
for var in unaligned.data_vars
|
||||||
|
}
|
||||||
|
with ProcessPoolExecutor(max_workers=10) as executor:
|
||||||
|
futures = {}
|
||||||
|
for i, geoms in track(
|
||||||
|
enumerate(cell_geometries), total=len(cell_geometries), description="Submitting cell extraction tasks..."
|
||||||
|
):
|
||||||
|
if len(geoms) == 0:
|
||||||
|
continue
|
||||||
|
elif len(geoms) == 1:
|
||||||
|
geom = geoms[0]
|
||||||
|
# Reduce the amount of data needed to be sent to the worker
|
||||||
|
# Since we dont mask the data, only isel operations are done here
|
||||||
|
# Thus, to properly extract the subset, another masked crop needs to be done in the worker
|
||||||
|
unaligned_subset = unaligned.odc.crop(geom, apply_mask=False)
|
||||||
|
futures[executor.submit(_extract_cell_data, unaligned_subset, geom)] = i
|
||||||
|
else:
|
||||||
|
# Same as above but for multiple parts
|
||||||
|
unaligned_subsets = [unaligned.odc.crop(geom, apply_mask=False) for geom in geoms]
|
||||||
|
futures[executor.submit(_extract_split_cell_data, unaligned_subsets, geoms)] = i
|
||||||
|
|
||||||
|
for future in track(
|
||||||
|
as_completed(futures), total=len(futures), description="Spatially aggregating ERA5 data..."
|
||||||
|
):
|
||||||
|
i = futures[future]
|
||||||
|
cell_data = future.result()
|
||||||
|
for var in unaligned.data_vars:
|
||||||
|
data[var][i, :] = cell_data[var]
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
@stopwatch("Creating aligned dataset", log=False)
|
@stopwatch("Creating aligned dataset", log=False)
|
||||||
def _create_aligned(
|
def _create_aligned(
|
||||||
ds: xr.Dataset, data: dict[str, np.ndarray], grid: Literal["hex", "healpix"], level: int
|
ds: xr.Dataset, data: dict[str, np.ndarray], grid: Literal["hex", "healpix"], level: int
|
||||||
|
|
@ -737,48 +802,38 @@ def spatial_agg(
|
||||||
level (int): Grid resolution level.
|
level (int): Grid resolution level.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
grid_gdf = grids.open(grid, level)
|
with stopwatch(f"Loading {grid} grid at level {level}"):
|
||||||
# ? Mask out water, since we don't want to aggregate over oceans
|
grid_gdf = grids.open(grid, level)
|
||||||
grid_gdf = watermask.clip_grid(grid_gdf)
|
# ? Mask out water, since we don't want to aggregate over oceans
|
||||||
grid_gdf = grid_gdf.to_crs("epsg:4326")
|
grid_gdf = watermask.clip_grid(grid_gdf)
|
||||||
|
grid_gdf = grid_gdf.to_crs("epsg:4326")
|
||||||
|
|
||||||
# Precompute the geometries to clip later
|
for agg in ["yearly", "seasonal", "shoulder"]:
|
||||||
daily_store = get_era5_stores("daily")
|
|
||||||
daily_unaligned = xr.open_zarr(daily_store, consolidated=False).set_coords("spatial_ref")
|
|
||||||
assert {"latitude", "longitude", "time"} == set(daily_unaligned.dims)
|
|
||||||
assert daily_unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {daily_unaligned.odc.crs}"
|
|
||||||
daily_unaligned = _correct_longs(daily_unaligned)
|
|
||||||
|
|
||||||
cell_geometries = [_get_corrected_geoms(row.geometry, daily_unaligned.odc.geobox) for _, row in grid_gdf.iterrows()]
|
|
||||||
|
|
||||||
for agg in ["summer", "winter", "yearly"]:
|
|
||||||
unaligned_store = get_era5_stores(agg)
|
unaligned_store = get_era5_stores(agg)
|
||||||
with stopwatch(f"Loading {agg} ERA5 data"):
|
with stopwatch(f"Loading {agg} ERA5 data"):
|
||||||
unaligned = xr.open_zarr(unaligned_store, consolidated=False).set_coords("spatial_ref")
|
unaligned = xr.open_zarr(unaligned_store, consolidated=False).set_coords("spatial_ref").load()
|
||||||
assert {"latitude", "longitude", "time"} == set(unaligned.dims)
|
assert {"latitude", "longitude", "time"} == set(unaligned.dims)
|
||||||
assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}"
|
assert unaligned.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {unaligned.odc.crs}"
|
||||||
unaligned = _correct_longs(unaligned)
|
unaligned = _correct_longs(unaligned)
|
||||||
|
|
||||||
data = {
|
with stopwatch("Precomputing cell geometries"):
|
||||||
var: np.full((len(grid_gdf), len(unaligned.time)), np.nan, dtype=np.float32) for var in unaligned.data_vars
|
with ProcessPoolExecutor(max_workers=20) as executor:
|
||||||
}
|
cell_geometries = list(
|
||||||
for i, geoms in track(
|
executor.map(
|
||||||
enumerate(cell_geometries),
|
_get_corrected_geoms,
|
||||||
total=len(grid_gdf),
|
[(row.geometry, unaligned.odc.geobox) for _, row in grid_gdf.iterrows()],
|
||||||
description=f"Spatially aggregating {agg} ERA5 data...",
|
)
|
||||||
):
|
)
|
||||||
if len(geoms) == 0:
|
|
||||||
print(f"Warning: No valid geometry for cell {grid_gdf.iloc[i].cell_id}.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
cell_data = extract_cell_data(unaligned, geoms)
|
data = _align_data(cell_geometries, unaligned)
|
||||||
for var in unaligned.data_vars:
|
|
||||||
data[var][i, :] = cell_data[var].values
|
|
||||||
|
|
||||||
aggregated = _create_aligned(unaligned, data, grid, level)
|
aggregated = _create_aligned(unaligned, data, grid, level)
|
||||||
store = get_era5_stores(agg, grid, level)
|
store = get_era5_stores(agg, grid, level)
|
||||||
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
|
with stopwatch(f"Saving spatially aggregated {agg} ERA5 data to {store}"):
|
||||||
|
aggregated.to_zarr(store, mode="w", consolidated=False, encoding=codecs.from_ds(aggregated))
|
||||||
print(f"Finished spatial matching for {agg} data.")
|
print(f"Finished spatial matching for {agg} data.")
|
||||||
|
|
||||||
|
print("### Stopwatch Summary ###")
|
||||||
stopwatch.summary()
|
stopwatch.summary()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
# ruff: noqa: D103
|
# ruff: noqa: D103
|
||||||
"""Paths for entropice data storage."""
|
"""Paths for entropice data storage."""
|
||||||
|
|
||||||
|
import datetime
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "data")).resolve() / "entropice"
|
DATA_DIR = Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None)).resolve() / "entropice"
|
||||||
|
DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcoding for FAST cluster
|
||||||
|
|
||||||
GRIDS_DIR = DATA_DIR / "grids"
|
GRIDS_DIR = DATA_DIR / "grids"
|
||||||
FIGURES_DIR = Path("figures")
|
FIGURES_DIR = Path("figures")
|
||||||
|
|
@ -13,6 +15,8 @@ DARTS_DIR = DATA_DIR / "darts"
|
||||||
ERA5_DIR = DATA_DIR / "era5"
|
ERA5_DIR = DATA_DIR / "era5"
|
||||||
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
|
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
|
||||||
WATERMASK_DIR = DATA_DIR / "watermask"
|
WATERMASK_DIR = DATA_DIR / "watermask"
|
||||||
|
TRAINING_DIR = DATA_DIR / "training"
|
||||||
|
RESULTS_DIR = DATA_DIR / "results"
|
||||||
|
|
||||||
GRIDS_DIR.mkdir(parents=True, exist_ok=True)
|
GRIDS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
|
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -20,9 +24,10 @@ DARTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
ERA5_DIR.mkdir(parents=True, exist_ok=True)
|
ERA5_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
|
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
WATERMASK_DIR.mkdir(parents=True, exist_ok=True)
|
WATERMASK_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
TRAINING_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
|
||||||
watermask_file = DATA_DIR / "simplified_water_polygons.shp"
|
|
||||||
|
|
||||||
dartsl2_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
|
dartsl2_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
|
||||||
dartsl2_cov_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
|
dartsl2_cov_file = DARTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
|
||||||
|
|
@ -46,7 +51,7 @@ def get_grid_viz_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||||
|
|
||||||
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
rtsfile = DARTS_DIR / f"{gridname}_rts.parquet"
|
rtsfile = DARTS_DIR / f"{gridname}_darts.parquet"
|
||||||
return rtsfile
|
return rtsfile
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -74,3 +79,18 @@ def get_era5_stores(
|
||||||
gridname = _get_gridname(grid, level)
|
gridname = _get_gridname(grid, level)
|
||||||
aligned_path = ERA5_DIR / f"{gridname}_{agg}_climate.zarr"
|
aligned_path = ERA5_DIR / f"{gridname}_{agg}_climate.zarr"
|
||||||
return aligned_path
|
return aligned_path
|
||||||
|
|
||||||
|
|
||||||
|
def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||||
|
gridname = _get_gridname(grid, level)
|
||||||
|
dataset_file = TRAINING_DIR / f"{gridname}_train_dataset.parquet"
|
||||||
|
return dataset_file
|
||||||
|
|
||||||
|
|
||||||
|
def get_cv_results_file(name: str, grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||||
|
gridname = _get_gridname(grid, level)
|
||||||
|
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
|
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}"
|
||||||
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
results_file = results_dir / "search_results.parquet"
|
||||||
|
return results_file
|
||||||
|
|
|
||||||
295
src/entropice/training.py
Normal file
295
src/entropice/training.py
Normal file
|
|
@ -0,0 +1,295 @@
|
||||||
|
# ruff: noqa: N806
|
||||||
|
"""Training dataset preparation and model training."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import cyclopts
|
||||||
|
import geopandas as gpd
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import seaborn as sns
|
||||||
|
import torch
|
||||||
|
import xarray as xr
|
||||||
|
from entropy import ESPAClassifier
|
||||||
|
from rich import pretty, traceback
|
||||||
|
from scipy.stats import loguniform, randint
|
||||||
|
from sklearn import set_config
|
||||||
|
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
|
||||||
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
|
from entropice.paths import (
|
||||||
|
get_cv_results_file,
|
||||||
|
get_darts_rts_file,
|
||||||
|
get_embeddings_store,
|
||||||
|
get_era5_stores,
|
||||||
|
get_train_dataset_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
traceback.install()
|
||||||
|
pretty.install()
|
||||||
|
|
||||||
|
set_config(array_api_dispatch=True)
|
||||||
|
|
||||||
|
sns.set_theme("talk", "whitegrid")
|
||||||
|
|
||||||
|
cli = cyclopts.App()
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
def prepare_dataset(grid: Literal["hex", "healpix"], level: int):
|
||||||
|
"""Prepare training dataset by combining DARTS RTS labels, ERA5 data, and embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||||
|
level (int): The grid level to use.
|
||||||
|
|
||||||
|
"""
|
||||||
|
rts = gpd.read_parquet(get_darts_rts_file(grid=grid, level=level))
|
||||||
|
# Filter to coverage
|
||||||
|
rts = rts[rts["darts_has_coverage"]]
|
||||||
|
# Convert hex cell_id to int
|
||||||
|
if grid == "hex":
|
||||||
|
rts["cell_id"] = rts["cell_id"].apply(lambda x: int(x, 16))
|
||||||
|
|
||||||
|
# Get era5 data
|
||||||
|
era5_store = get_era5_stores("yearly", grid=grid, level=level)
|
||||||
|
era5 = xr.open_zarr(era5_store, consolidated=False)
|
||||||
|
era5 = era5.sel(cell_ids=rts["cell_id"].values)
|
||||||
|
|
||||||
|
era5_df = []
|
||||||
|
for var in era5.data_vars:
|
||||||
|
df = era5[var].drop_vars("spatial_ref").to_dataframe()
|
||||||
|
df["year"] = df.index.get_level_values("time").year
|
||||||
|
df = (
|
||||||
|
df.pivot_table(index="cell_ids", columns="year", values=var)
|
||||||
|
.rename(columns=lambda x: f"{var}_{x}")
|
||||||
|
.rename_axis(None, axis=1)
|
||||||
|
)
|
||||||
|
era5_df.append(df)
|
||||||
|
era5_df = pd.concat(era5_df, axis=1)
|
||||||
|
|
||||||
|
# TODO: season and shoulder data
|
||||||
|
|
||||||
|
# Get embeddings data
|
||||||
|
embs_store = get_embeddings_store(grid=grid, level=level)
|
||||||
|
embeddings = xr.open_zarr(embs_store, consolidated=False).__xarray_dataarray_variable__
|
||||||
|
embeddings = embeddings.sel(cell=rts["cell_id"].values)
|
||||||
|
|
||||||
|
embeddings_df = embeddings.to_dataframe(name="value")
|
||||||
|
embeddings_df = embeddings_df.pivot_table(index="cell", columns=["year", "agg", "band"], values="value")
|
||||||
|
embeddings_df.columns = [f"{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
|
||||||
|
|
||||||
|
# Combine datasets by cell id / cell
|
||||||
|
dataset = rts.set_index("cell_id").join(era5_df).join(embeddings_df)
|
||||||
|
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
|
||||||
|
|
||||||
|
dataset_file = get_train_dataset_file(grid=grid, level=level)
|
||||||
|
dataset.reset_index().to_parquet(dataset_file)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
def random_cv(grid: Literal["hex", "healpix"], level: int):
    """Perform random cross-validation on the training dataset.

    Loads the prepared training parquet for the given grid/level, runs a
    RandomizedSearchCV over the ESPAClassifier hyperparameters on the GPU,
    reports the best parameters, stores the full CV results as parquet and
    triggers the result plots.

    Args:
        grid (Literal["hex", "healpix"]): The grid type to use.
        level (int): The grid level to use.

    """
    data = get_train_dataset_file(grid=grid, level=level)
    data = gpd.read_parquet(data)

    # Drop label/metadata columns so only feature columns remain; any column
    # prefixed "darts_" is treated as metadata, "darts_has_rts" is the label.
    cols_to_drop = ["cell_id", "geometry", "darts_has_rts"]
    cols_to_drop += [col for col in data.columns if col.startswith("darts_")]
    # Rows with any NaN feature are removed; labels are aligned via the
    # surviving index.
    X_data = data.drop(columns=cols_to_drop).dropna()
    y_data = data.loc[X_data.index, "darts_has_rts"]
    X = X_data.to_numpy(dtype="float32")
    y = y_data.to_numpy(dtype="int8")
    # Move data to GPU 0 up front; the estimator and sklearn's array-API
    # support then keep the work on-device.
    X, y = torch.asarray(X, device=0), torch.asarray(y, device=0)
    print(f"{X.shape=}, {y.shape=}")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
    print(f"{X_train.shape=}, {X_test.shape=}, {y_train.shape=}, {y_test.shape=}")

    # Hyperparameter distributions: both eps parameters are sampled
    # log-uniformly over 10 decades, initial_K uniformly in [20, 400).
    param_grid = {
        "eps_cl": loguniform(1e-3, 1e7),
        "eps_e": loguniform(1e-3, 1e7),
        "initial_K": randint(20, 400),
    }

    # Constructor args (20, 0.1, 0.1) are placeholders; the search overwrites
    # initial_K/eps_cl/eps_e per candidate.
    clf = ESPAClassifier(20, 0.1, 0.1, random_state=42)
    cv = KFold(n_splits=5, shuffle=True, random_state=42)
    metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]  # "roc_auc" does not work on GPU
    search = RandomizedSearchCV(
        clf,
        param_grid,
        n_iter=20,
        n_jobs=24,
        cv=cv,
        random_state=42,
        verbose=10,
        scoring=metrics,
        refit="f1",  # best estimator is chosen (and refit) by f1
    )

    print(f"Starting RandomizedSearchCV with {search.n_iter} candidates...")
    with stopwatch(f"RandomizedSearchCV fitting for {search.n_iter} candidates"):
        # max_iter=100 is forwarded as a fit parameter to every candidate fit.
        search.fit(X_train, y_train, max_iter=100)

    print("Best parameters combination found:")
    best_parameters = search.best_estimator_.get_params()
    for param_name in sorted(param_grid.keys()):
        print(f"{param_name}: {best_parameters[param_name]}")

    # NOTE(review): search.score uses the refit metric (f1), so both prints
    # report f1 despite saying "Accuracy" — confirm wording is intentional.
    test_accuracy = search.score(X_test, y_test)
    print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
    print(f"Accuracy on test set: {test_accuracy:.3f}")

    # Store the search results
    results = pd.DataFrame(search.cv_results_)
    # Parse the params into individual columns
    params = pd.json_normalize(results["params"])
    # Concatenate the params columns with the original DataFrame
    results = pd.concat([results.drop(columns=["params"]), params], axis=1)
    results["grid"] = grid
    results["level"] = level
    results_file = get_cv_results_file("random_search", grid=grid, level=level)
    print(f"Storing CV results to {results_file}")
    results.to_parquet(results_file)

    stopwatch.summary()
    print("Done.")

    # Immediately produce the analysis plots next to the stored results.
    plot_random_cv_results(results_file)
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_k_binned(
    results: pd.DataFrame, target: str, *, vmin_percentile: float | None = None, vmax_percentile: float | None = None
):
    """Scatter-plot ``target`` over (eps_e, eps_cl), faceted by binned initial_K.

    Args:
        results (pd.DataFrame): CV results; must contain the columns
            "initial_K_binned", "eps_e", "eps_cl" and ``target``.
        target (str): Column whose values color the points.
        vmin_percentile (float | None): If set, clip the lower bound of the
            color scale to this percentile of ``target``. Mutually exclusive
            with ``vmax_percentile``.
        vmax_percentile (float | None): If set, clip the upper bound of the
            color scale to this percentile of ``target``.

    Returns:
        The seaborn FacetGrid with a shared horizontal colorbar.

    Raises:
        ValueError: If both percentiles are given or a required column is missing.

    """
    # Validate with explicit exceptions instead of asserts: asserts are
    # stripped under `python -O` and would silently skip these checks.
    if vmin_percentile is not None and vmax_percentile is not None:
        raise ValueError("Only one of vmin_percentile or vmax_percentile can be set.")
    for required in ("initial_K_binned", target, "eps_e", "eps_cl"):
        if required not in results.columns:
            raise ValueError(f"{required} column not found in results.")

    # Add a colorbar instead of the sampled legend.
    cmap = sns.color_palette("ch:", as_cmap=True)
    # Sophisticated normalization: clip at most one end of the color scale at
    # a percentile of the target so outliers do not wash out the palette.
    if vmin_percentile is not None:
        norm = mcolors.Normalize(vmin=np.percentile(results[target], vmin_percentile))
    elif vmax_percentile is not None:
        norm = mcolors.Normalize(vmax=np.percentile(results[target], vmax_percentile))
    else:
        norm = mcolors.Normalize()
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)

    # Nice col-wrap based on the number of facets: prefer a wrap that divides
    # the facet count evenly, falling back to 3.
    n_cols = results["initial_K_binned"].unique().size
    col_wrap = 5 if n_cols % 5 == 0 else (4 if n_cols % 4 == 0 else 3)

    scatter = sns.relplot(
        data=results,
        x="eps_e",
        y="eps_cl",
        hue=target,
        hue_norm=sm.norm,
        palette=cmap,
        legend=False,
        col="initial_K_binned",
        col_wrap=col_wrap,
    )

    # Both eps parameters are sampled log-uniformly, so use log axes everywhere.
    for ax in scatter.axes.flat:
        ax.set_xscale("log")
        ax.set_yscale("log")

    # Tight layout
    scatter.figure.tight_layout()

    # Add a shared colorbar at the bottom
    scatter.figure.subplots_adjust(bottom=0.15)  # Make room for the colorbar
    cbar_ax = scatter.figure.add_axes([0.15, 0.05, 0.7, 0.02])  # [left, bottom, width, height]
    cbar = scatter.figure.colorbar(sm, cax=cbar_ax, orientation="horizontal")
    cbar.set_label(target)

    return scatter
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_eps_binned(results: pd.DataFrame, target: str, metric: str):
    """Plot ``metric`` against initial_K, colored by ``target``, faceted by the other eps parameter.

    Args:
        results (pd.DataFrame): CV results; must contain "initial_K",
            ``metric``, the ``target`` column and the binned column of the
            other eps parameter.
        target (str): Either "eps_cl" or "eps_e"; used as hue. The binned
            version of the *other* eps parameter is used as facet column.
        metric (str): The result column plotted on the y-axis.

    Returns:
        The seaborn FacetGrid.

    Raises:
        ValueError: For an unknown ``target`` or a missing column.

    """
    if target == "eps_cl":
        hue = "eps_cl"
        col = "eps_e_binned"
    elif target == "eps_e":
        hue = "eps_e"
        col = "eps_cl_binned"
    else:
        # Bug fix: previously an unrecognized target left `hue`/`col` unbound
        # and crashed with UnboundLocalError; fail with a clear message instead.
        raise ValueError(f"target must be 'eps_cl' or 'eps_e', got {target!r}.")

    # Explicit exceptions instead of asserts (asserts are stripped under -O).
    for required in ("initial_K", metric, hue, col):
        if required not in results.columns:
            raise ValueError(f"{required} column not found in results.")

    # eps values span many decades, so normalize the hue on a log scale.
    return sns.relplot(results, x="initial_K", y=metric, hue=hue, col=col, col_wrap=5, hue_norm=mcolors.LogNorm())
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
def plot_random_cv_results(file: Path):
    """Plot analysis of the results from the RandomCVSearch.

    Reads the stored CV results, bins the hyperparameters for faceting and
    writes one PDF per plot next to the results file.

    Args:
        file (Path): The file of the results.

    """
    print(f"Plotting random CV results from {file}...")
    results = pd.read_parquet(file)
    # Bin the initial_K into 40er bins (left-closed intervals over [20, 400))
    results["initial_K_binned"] = pd.cut(results["initial_K"], bins=range(20, 401, 40), right=False)
    # Bin the eps_cl and eps_e into logarithmic bins (9 bins over 1e-3..1e7,
    # matching the loguniform search space)
    eps_cl_bins = np.logspace(-3, 7, num=10)
    eps_e_bins = np.logspace(-3, 7, num=10)
    results["eps_cl_binned"] = pd.cut(results["eps_cl"], bins=eps_cl_bins)
    results["eps_e_binned"] = pd.cut(results["eps_e"], bins=eps_e_bins)

    # Figures are written next to the results parquet.
    figdir = file.parent

    # K-Plots: one mean/std pair per metric, with and without percentile clipping.
    metrics = ["accuracy", "recall", "precision", "f1", "jaccard"]
    for metric in metrics:
        _plot_k_binned(
            results,
            f"mean_test_{metric}",
            vmin_percentile=50,
        ).figure.savefig(figdir / f"params3d-mean_{metric}.pdf")
        _plot_k_binned(
            results,
            f"std_test_{metric}",
            vmax_percentile=50,
        ).figure.savefig(figdir / f"params3d-std_{metric}.pdf")
        _plot_k_binned(results, f"mean_test_{metric}").figure.savefig(figdir / f"params3d-mean_{metric}-noperc.pdf")
        _plot_k_binned(results, f"std_test_{metric}").figure.savefig(figdir / f"params3d-std_{metric}-noperc.pdf")

    # eps-Plots
    # NOTE(review): `metric` here is the loop variable leaked from the K-plots
    # loop above, i.e. only the LAST metric ("jaccard") is plotted. If eps
    # plots for every metric were intended, these two calls should live inside
    # the loop — confirm before changing, as it alters the produced files.
    _plot_eps_binned(
        results,
        "eps_cl",
        f"mean_test_{metric}",
    ).figure.savefig(figdir / f"k-eps_cl-mean_{metric}.pdf")
    _plot_eps_binned(
        results,
        "eps_e",
        f"mean_test_{metric}",
    ).figure.savefig(figdir / f"k-eps_e-mean_{metric}.pdf")

    # Close all figures to free matplotlib memory after batch saving.
    plt.close("all")
    print("Done.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: dispatch to the CLI app defined in this module.
    cli()
|
||||||
|
|
@ -27,6 +27,6 @@ def clip_grid(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
watermask = open()
|
watermask = open()
|
||||||
watermask = watermask.to_crs("EPSG:3413")
|
watermask = watermask.to_crs(gdf.crs)
|
||||||
gdf = gdf.overlay(watermask, how="difference")
|
gdf = gdf.overlay(watermask, how="difference")
|
||||||
return gdf
|
return gdf
|
||||||
|
|
|
||||||
13
src/entropice/xvec.py
Normal file
13
src/entropice/xvec.py
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
import xarray as xr
|
||||||
|
import xdggs
|
||||||
|
import xvec
|
||||||
|
|
||||||
|
|
||||||
|
def to_xvec(ds: xr.Dataset) -> xr.Dataset:
    """Convert a DGGS cell-indexed dataset into an xvec geometry-indexed dataset.

    The dataset's "cell_ids" dimension is replaced by a "geometry" dimension
    carrying the cell boundary polygons as an xvec geometry index, while the
    original cell ids are kept as a non-index coordinate.

    Args:
        ds (xr.Dataset): A dataset with a DGGS index (usable via the ``dggs``
            accessor) and a "cell_ids" coordinate.

    Returns:
        xr.Dataset: The geometry-indexed dataset.

    """
    # Cell boundary polygons for every DGGS cell.
    # NOTE(review): the decode() wrapper presumably re-attaches the dggs index
    # to the boundaries — confirm against xdggs docs.
    ds["geometry"] = xdggs.decode(ds.dggs.cell_boundaries())
    # Preserve the raw cell ids before the index is replaced below.
    cell_ids = ds.cell_ids.values
    # Swap the index from cell ids to geometries; dim and var are renamed so
    # "geometry" becomes the indexed dimension coordinate.
    ds = ds.set_index(cell_ids="geometry").rename_dims({"cell_ids": "geometry"}).rename_vars({"cell_ids": "geometry"})
    # Geometries are interpreted as WGS84 lon/lat.
    ds = ds.xvec.set_geom_indexes("geometry", crs="epsg:4326")
    # Re-attach the cell ids as a plain (non-index) coordinate along "geometry".
    ds["cell_ids"] = ("geometry", cell_ids)
    ds = ds.set_coords("cell_ids")
    return ds
|
||||||
65
uv.lock
generated
65
uv.lock
generated
|
|
@ -136,6 +136,15 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/2b/f0/09a30ca0551af20c7cefa7464b7ccb6f5407a550b83c4dcb15c410814849/anywidget-0.9.18-py3-none-any.whl", hash = "sha256:944b82ef1dd17b8ff0fb6d1f199f613caf9111338e6e2857da478f6e73770cb8", size = 220671 },
|
{ url = "https://files.pythonhosted.org/packages/2b/f0/09a30ca0551af20c7cefa7464b7ccb6f5407a550b83c4dcb15c410814849/anywidget-0.9.18-py3-none-any.whl", hash = "sha256:944b82ef1dd17b8ff0fb6d1f199f613caf9111338e6e2857da478f6e73770cb8", size = 220671 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "appdirs"
|
||||||
|
version = "1.4.4"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/d7/d8/05696357e0311f5b5c316d7b95f46c669dd9c15aaeecbb48c7d0aeb88c40/appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41", size = 13470 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128", size = 9566 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "appnope"
|
name = "appnope"
|
||||||
version = "0.1.4"
|
version = "0.1.4"
|
||||||
|
|
@ -145,6 +154,27 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321 },
|
{ url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "array-api-compat"
|
||||||
|
version = "1.12.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/8d/bd/9fa5c7c5621698d5632cc852a79fbbdc28024462c9396698e5fdcb395f37/array_api_compat-1.12.0.tar.gz", hash = "sha256:585bc615f650de53ac24b7c012baecfcdd810f50df3573be47e6dd9fa20df974", size = 99883 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e0/b1/0542e0cab6f49f151a2d7a42400f84f706fc0b64e85dc1f56708b2e9fd37/array_api_compat-1.12.0-py3-none-any.whl", hash = "sha256:a0b4795b6944a9507fde54679f9350e2ad2b1e2acf4a2408a098cdc27f890a8b", size = 58156 },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "array-api-extra"
|
||||||
|
version = "0.9.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "array-api-compat" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/cb/be/387d596e0ed6d191988d7b61bb0e252bb8023965485218a33bd1a8ccc72a/array_api_extra-0.9.0.tar.gz", hash = "sha256:2d49e38394f5a96caef17b80964d228636e6fdf1dada33697ab8ec3f5364eb77", size = 81579 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8c/17/8a7503c9dc3f8cc2868c632291e5d224822b05ae62f1279c529c459368d2/array_api_extra-0.9.0-py3-none-any.whl", hash = "sha256:36b34e29380b678007f151511be950cb2ea199606fe4a7ad466efc5044ea9e44", size = 48234 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "arro3-compute"
|
name = "arro3-compute"
|
||||||
version = "0.6.5"
|
version = "0.6.5"
|
||||||
|
|
@ -1217,10 +1247,11 @@ dependencies = [
|
||||||
{ name = "distributed" },
|
{ name = "distributed" },
|
||||||
{ name = "earthengine-api" },
|
{ name = "earthengine-api" },
|
||||||
{ name = "eemont" },
|
{ name = "eemont" },
|
||||||
{ name = "entropyc" },
|
{ name = "entropy" },
|
||||||
{ name = "flox" },
|
{ name = "flox" },
|
||||||
{ name = "folium" },
|
{ name = "folium" },
|
||||||
{ name = "geemap" },
|
{ name = "geemap" },
|
||||||
|
{ name = "geocube" },
|
||||||
{ name = "geopandas" },
|
{ name = "geopandas" },
|
||||||
{ name = "h3" },
|
{ name = "h3" },
|
||||||
{ name = "h5netcdf" },
|
{ name = "h5netcdf" },
|
||||||
|
|
@ -1264,10 +1295,11 @@ requires-dist = [
|
||||||
{ name = "distributed", specifier = ">=2025.5.1" },
|
{ name = "distributed", specifier = ">=2025.5.1" },
|
||||||
{ name = "earthengine-api", specifier = ">=1.6.9" },
|
{ name = "earthengine-api", specifier = ">=1.6.9" },
|
||||||
{ name = "eemont", specifier = ">=2025.7.1" },
|
{ name = "eemont", specifier = ">=2025.7.1" },
|
||||||
{ name = "entropyc", git = "ssh://git@github.com/AlbertEMC2Stein/entropyc?branch=refactor%2Ftobi" },
|
{ name = "entropy", git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" },
|
||||||
{ name = "flox", specifier = ">=0.10.4" },
|
{ name = "flox", specifier = ">=0.10.4" },
|
||||||
{ name = "folium", specifier = ">=0.19.7" },
|
{ name = "folium", specifier = ">=0.19.7" },
|
||||||
{ name = "geemap", specifier = ">=0.36.3" },
|
{ name = "geemap", specifier = ">=0.36.3" },
|
||||||
|
{ name = "geocube", specifier = ">=0.7.1,<0.8" },
|
||||||
{ name = "geopandas", specifier = ">=1.1.0" },
|
{ name = "geopandas", specifier = ">=1.1.0" },
|
||||||
{ name = "h3", specifier = ">=4.2.2" },
|
{ name = "h3", specifier = ">=4.2.2" },
|
||||||
{ name = "h5netcdf", specifier = ">=1.6.4" },
|
{ name = "h5netcdf", specifier = ">=1.6.4" },
|
||||||
|
|
@ -1300,11 +1332,13 @@ requires-dist = [
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "entropyc"
|
name = "entropy"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
source = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc?branch=refactor%2Ftobi#22a191d194a76b6c182481acb2af1bde3f60b49e" }
|
source = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "numpy" },
|
{ name = "array-api-compat" },
|
||||||
|
{ name = "array-api-extra" },
|
||||||
|
{ name = "scikit-learn" },
|
||||||
{ name = "scipy" },
|
{ name = "scipy" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -1569,6 +1603,27 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/4f/6b/13166c909ad2f2d76b929a4227c952630ebaf0d729f6317eb09cbceccbab/geocoder-1.38.1-py2.py3-none-any.whl", hash = "sha256:a733e1dfbce3f4e1a526cac03aadcedb8ed1239cf55bd7f3a23c60075121a834", size = 98590 },
|
{ url = "https://files.pythonhosted.org/packages/4f/6b/13166c909ad2f2d76b929a4227c952630ebaf0d729f6317eb09cbceccbab/geocoder-1.38.1-py2.py3-none-any.whl", hash = "sha256:a733e1dfbce3f4e1a526cac03aadcedb8ed1239cf55bd7f3a23c60075121a834", size = 98590 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "geocube"
|
||||||
|
version = "0.7.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "appdirs" },
|
||||||
|
{ name = "click" },
|
||||||
|
{ name = "geopandas" },
|
||||||
|
{ name = "numpy" },
|
||||||
|
{ name = "odc-geo" },
|
||||||
|
{ name = "pyproj" },
|
||||||
|
{ name = "rasterio" },
|
||||||
|
{ name = "rioxarray" },
|
||||||
|
{ name = "scipy" },
|
||||||
|
{ name = "xarray" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/6c/03/d39b7a372f2054ae374247c8e0130b8f23aee89b0624c8f04fa49b2c1199/geocube-0.7.1.tar.gz", hash = "sha256:5f0f4a2143b379434d81172ae8c9fb49c2c5ff2f9723864ed79d3947a68ea37f", size = 20528 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/31/c6/a9341239e2e2953537b9e90a46ebc59f2e122247a3fe22373cc37520fc44/geocube-0.7.1-py3-none-any.whl", hash = "sha256:661a12c0b2106f27477290b5f18e76eb5855c9c50cac7fd19028fde4babca628", size = 23107 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "geographiclib"
|
name = "geographiclib"
|
||||||
version = "2.1"
|
version = "2.1"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue