Update docs, instructions and format code
This commit is contained in:
parent
fca232da91
commit
4260b492ab
29 changed files with 987 additions and 467 deletions
83
tests/l2dataset_validation.py
Normal file
83
tests/l2dataset_validation.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
from itertools import product
|
||||
|
||||
import xarray as xr
|
||||
from rich.progress import track
|
||||
|
||||
from entropice.utils.paths import (
|
||||
get_arcticdem_stores,
|
||||
get_embeddings_store,
|
||||
get_era5_stores,
|
||||
)
|
||||
from entropice.utils.types import Grid, L2SourceDataset
|
||||
|
||||
|
||||
def validate_l2_dataset(grid: Grid, level: int, l2ds: L2SourceDataset) -> bool:
|
||||
"""Validate if the L2 dataset exists for the given grid and level.
|
||||
|
||||
Args:
|
||||
grid (Grid): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
l2ds (L2SourceDataset): The L2 source dataset to validate.
|
||||
|
||||
Returns:
|
||||
bool: True if the dataset exists and does not contain NaNs, False otherwise.
|
||||
|
||||
"""
|
||||
if l2ds == "ArcticDEM":
|
||||
store = get_arcticdem_stores(grid, level)
|
||||
elif l2ds == "ERA5-shoulder" or l2ds == "ERA5-seasonal" or l2ds == "ERA5-yearly":
|
||||
agg = l2ds.split("-")[1]
|
||||
store = get_era5_stores(agg, grid, level) # type: ignore
|
||||
elif l2ds == "AlphaEarth":
|
||||
store = get_embeddings_store(grid, level)
|
||||
else:
|
||||
raise ValueError(f"Unsupported L2 source dataset: {l2ds}")
|
||||
|
||||
if not store.exists():
|
||||
print("\t Dataset store does not exist")
|
||||
return False
|
||||
|
||||
ds = xr.open_zarr(store, consolidated=False)
|
||||
has_nan = False
|
||||
for var in ds.data_vars:
|
||||
n_nans = ds[var].isnull().sum().compute().item()
|
||||
if n_nans > 0:
|
||||
print(f"\t Dataset contains {n_nans} NaNs in variable {var}")
|
||||
has_nan = True
|
||||
if has_nan:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
grid_levels: set[tuple[Grid, int]] = {
|
||||
("hex", 3),
|
||||
("hex", 4),
|
||||
("hex", 5),
|
||||
("hex", 6),
|
||||
("healpix", 6),
|
||||
("healpix", 7),
|
||||
("healpix", 8),
|
||||
("healpix", 9),
|
||||
("healpix", 10),
|
||||
}
|
||||
l2_source_datasets: list[L2SourceDataset] = [
|
||||
"ArcticDEM",
|
||||
"ERA5-shoulder",
|
||||
"ERA5-seasonal",
|
||||
"ERA5-yearly",
|
||||
"AlphaEarth",
|
||||
]
|
||||
|
||||
for (grid, level), l2ds in track(
|
||||
product(grid_levels, l2_source_datasets),
|
||||
total=len(grid_levels) * len(l2_source_datasets),
|
||||
description="Validating L2 datasets...",
|
||||
):
|
||||
is_valid = validate_l2_dataset(grid, level, l2ds)
|
||||
status = "VALID" if is_valid else "INVALID"
|
||||
print(f"L2 Dataset Validation - Grid: {grid}, Level: {level}, L2 Source: {l2ds} => {status}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue