from itertools import product

import xarray as xr
from rich.progress import track

from entropice.utils.paths import (
    get_arcticdem_stores,
    get_embeddings_store,
    get_era5_stores,
)
from entropice.utils.types import Grid, L2SourceDataset


def validate_l2_dataset(grid: Grid, level: int, l2ds: L2SourceDataset) -> bool:
    """Validate if the L2 dataset exists for the given grid and level.

    The dataset is considered valid when its store exists and none of its
    data variables contain NaN values. Diagnostic messages for each failure
    are printed as they are found.

    Args:
        grid (Grid): The grid type to use.
        level (int): The grid level to use.
        l2ds (L2SourceDataset): The L2 source dataset to validate.

    Returns:
        bool: True if the dataset exists and does not contain NaNs, False otherwise.

    Raises:
        ValueError: If ``l2ds`` is not one of the supported L2 source datasets.
    """
    if l2ds == "ArcticDEM":
        store = get_arcticdem_stores(grid, level)
    elif l2ds in ("ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly"):
        # The ERA5 variants share one store getter and differ only in the
        # aggregation suffix after the "-" (e.g. "seasonal").
        agg = l2ds.split("-")[1]
        store = get_era5_stores(agg, grid, level)  # type: ignore
    elif l2ds == "AlphaEarth":
        store = get_embeddings_store(grid, level)
    else:
        raise ValueError(f"Unsupported L2 source dataset: {l2ds}")

    if not store.exists():
        print("\t Dataset store does not exist")
        return False

    has_nan = False
    # Use the dataset as a context manager so the store handle is released
    # after each check; main() opens many stores in a single process.
    with xr.open_zarr(store, consolidated=False) as ds:
        # Check every variable (no early exit) so all NaN-bearing variables
        # are reported in one run.
        for var in ds.data_vars:
            n_nans = ds[var].isnull().sum().compute().item()
            if n_nans > 0:
                print(f"\t Dataset contains {n_nans} NaNs in variable {var}")
                has_nan = True
    return not has_nan


def main():
    """Validate every L2 source dataset for every configured grid/level pair.

    Prints one VALID/INVALID status line per (grid, level, dataset)
    combination while displaying a progress bar.
    """
    # An ordered list, not a set: string hashing is randomized per process,
    # so a set would change the report order on every run.
    grid_levels: list[tuple[Grid, int]] = [
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    ]
    l2_source_datasets: list[L2SourceDataset] = [
        "ArcticDEM",
        "ERA5-shoulder",
        "ERA5-seasonal",
        "ERA5-yearly",
        "AlphaEarth",
    ]

    for (grid, level), l2ds in track(
        product(grid_levels, l2_source_datasets),
        # product() is a lazy iterator, so the total must be given explicitly.
        total=len(grid_levels) * len(l2_source_datasets),
        description="Validating L2 datasets...",
    ):
        is_valid = validate_l2_dataset(grid, level, l2ds)
        status = "VALID" if is_valid else "INVALID"
        print(f"L2 Dataset Validation - Grid: {grid}, Level: {level}, L2 Source: {l2ds} => {status}")


if __name__ == "__main__":
    main()