138 lines
5 KiB
Python
138 lines
5 KiB
Python
import dask.distributed as dd
|
|
import xarray as xr
|
|
from rich import print
|
|
|
|
import entropice.utils.codecs
|
|
from entropice.utils.paths import get_era5_stores
|
|
|
|
|
|
def print_info(daily_raw=None, show_vars: bool = True):
    """Print a summary of the daily dataset: dims, chunk layout and encodings.

    Args:
        daily_raw: Dataset-like object exposing ``sizes``, ``chunksizes``,
            ``encoding`` and ``data_vars``. When ``None``, the store returned
            by ``get_era5_stores("daily")`` is opened with ``xr.open_zarr``.
        show_vars: When true, additionally print the encoding of every data
            variable.
    """
    if daily_raw is None:
        daily_store = get_era5_stores("daily")
        daily_raw = xr.open_zarr(daily_store, consolidated=False)

    print("=== Daily INFO ===")
    print(f" Dims: {daily_raw.sizes}")

    # The size estimate assumes 4-byte (float32) values for every variable.
    bytes_per_value = 4
    numchunks = 1
    chunksizes = {}
    chunk_bytes = bytes_per_value
    for d, cs in daily_raw.chunksizes.items():
        largest = max(cs)  # chunks along a dim may be ragged; report the largest
        numchunks *= len(cs)
        chunksizes[d] = largest
        chunk_bytes *= largest

    # Bytes -> megabytes. The previous divisor ``10e6`` is 1e7, which made the
    # reported size ten times too small.
    approxchunksize = chunk_bytes / 1e6
    print(f" Chunks: {chunksizes} (~{approxchunksize:.2f}MB) => {numchunks} total")
    print(f" Encoding: {daily_raw.encoding}")

    if show_vars:
        print(" Variables:")
        for var in daily_raw.data_vars:
            da = daily_raw[var]
            print(f" {var} Encoding:")
            print(da.encoding)
            print("")
|
|
|
|
|
|
def rechunk(use_shards: bool = False):
    """Rewrite the daily store with 120-step time chunks spanning the full grid.

    The source store is ``get_era5_stores("daily")``. The result is written
    next to it as ``<stem>_rechunked`` or, when sharding is requested, as
    ``<stem>_rechunked_sharded``. Previously the sharded name was used
    unconditionally, so a default run never produced the ``<stem>_rechunked``
    store that ``validate`` opens.

    Args:
        use_shards: When true, also request shards of 1200 time steps in the
            per-variable encoding. A warning is printed because this path is
            known to be broken (see the linked xarray issue).
    """
    if use_shards:
        # ! MEEEP: https://github.com/pydata/xarray/issues/10831
        print("WARNING! Rechunking with shards. This is known to be broken in xarray/dask!")

    with (
        dd.LocalCluster(n_workers=8, threads_per_worker=5, memory_limit="20GB") as cluster,
        dd.Client(cluster) as client,
    ):
        print(f"Dashboard: {client.dashboard_link}")

        daily_store = get_era5_stores("daily")
        # Keep sharded and non-sharded outputs apart so neither overwrites the
        # other and the non-sharded name matches what validation reads.
        suffix = "_rechunked_sharded" if use_shards else "_rechunked"
        daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}{suffix}")

        daily_raw = xr.open_zarr(daily_store, consolidated=False)
        daily_raw = daily_raw.chunk(
            {
                "time": 120,
                "latitude": -1,  # Should be 337,
                "longitude": -1,  # Should be 3600
            }
        )

        encoding = entropice.utils.codecs.from_ds(daily_raw, filter_existing=False)
        for var in daily_raw.data_vars:
            # NOTE(review): the hard-coded tuples assume every data variable is
            # laid out as (time, latitude, longitude) on a 337 x 3600 grid —
            # confirm against the store.
            encoding[var]["chunks"] = (120, 337, 3600)
            if use_shards:
                encoding[var]["shards"] = (1200, 337, 3600)
        print(encoding)

        daily_raw.to_zarr(daily_store_rechunked, mode="w", consolidated=False, encoding=encoding)
|
|
|
|
|
|
def validate():
    """Compare the daily store with its ``<stem>_rechunked`` copy and print a report.

    Both stores are opened with ``xr.open_zarr``. The report covers dimension
    sizes, the set of data variables, and a per-variable comparison. For
    variables that are not ``equals``, it also prints whether the values are
    numerically close and whether attributes or encodings differ. The final
    verdict is successful only if dimensions, variable sets and all shared
    variables match.
    """
    # Imported locally (as before), but once instead of inside the loop.
    import numpy as np

    daily_store = get_era5_stores("daily")
    daily_raw = xr.open_zarr(daily_store, consolidated=False)

    daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked")
    daily_rechunked = xr.open_zarr(daily_store_rechunked, consolidated=False)

    print("\n=== Comparing Datasets ===")

    all_equal = True

    # Compare dimensions
    if daily_raw.sizes != daily_rechunked.sizes:
        # A structural mismatch must fail the overall verdict; previously it
        # was printed but the final message could still claim success.
        all_equal = False
        print("❌ Dimensions differ:")
        print(f" Original: {daily_raw.sizes}")
        print(f" Rechunked: {daily_rechunked.sizes}")
    else:
        print("✅ Dimensions match")

    # Compare variables
    raw_vars = set(daily_raw.data_vars)
    rechunked_vars = set(daily_rechunked.data_vars)
    if raw_vars != rechunked_vars:
        all_equal = False  # missing/extra variables are a validation failure too
        print("❌ Variables differ:")
        print(f" Only in original: {raw_vars - rechunked_vars}")
        print(f" Only in rechunked: {rechunked_vars - raw_vars}")
    else:
        print("✅ Variables match")

    # Compare each variable
    print("\n=== Variable Comparison ===")
    # Sorted so the report order is stable between runs (set order is not).
    for var in sorted(raw_vars & rechunked_vars, key=str):
        raw_var = daily_raw[var]
        rechunked_var = daily_rechunked[var]

        if raw_var.equals(rechunked_var):
            print(f"✅ {var}: Equal")
            continue

        all_equal = False
        print(f"❌ {var}: NOT Equal")

        # Check if values are equal
        try:
            # ``.shape`` is metadata only; ``.values.shape`` loaded the whole
            # array into memory just to read its shape.
            values_equal = raw_var.shape == rechunked_var.shape
            if values_equal:
                values_equal = np.allclose(raw_var.values, rechunked_var.values, equal_nan=True)

            if values_equal:
                print(" → Values are numerically equal (likely metadata/encoding difference)")
            else:
                print(" → Values differ!")
                print(f" Original shape: {raw_var.shape}")
                print(f" Rechunked shape: {rechunked_var.shape}")
        except Exception as e:
            # Deliberately broad: this is a best-effort diagnostic and one
            # failing variable should not abort the rest of the report.
            print(f" → Error comparing values: {e}")

        # Check attributes
        if raw_var.attrs != rechunked_var.attrs:
            print(" → Attributes differ:")
            print(f" Original: {raw_var.attrs}")
            print(f" Rechunked: {rechunked_var.attrs}")

        # Check encoding
        if raw_var.encoding != rechunked_var.encoding:
            print(" → Encoding differs:")
            print(f" Original: {raw_var.encoding}")
            print(f" Rechunked: {rechunked_var.encoding}")

    if all_equal:
        print("\n✅ Validation successful: All datasets are equal.")
    else:
        print("\n❌ Validation failed: Datasets have differences (see above).")
|
|
|
|
|
|
# Script entry point: when executed directly (not imported), run the
# comparison between the original and the rechunked daily store.
if __name__ == "__main__":
    validate()
|