entropice/scripts/rechunk_zarr.py

138 lines
5 KiB
Python

import dask.distributed as dd
import xarray as xr
from rich import print
import entropice.utils.codecs
from entropice.utils.paths import get_era5_stores
def print_info(daily_raw=None, show_vars: bool = True):
    """Print dimensions, chunking and encoding of the daily ERA5 dataset.

    Args:
        daily_raw: Dataset to describe. When ``None``, the store returned by
            ``get_era5_stores("daily")`` is opened with ``xr.open_zarr``.
        show_vars: Also print the encoding of every data variable.
    """
    if daily_raw is None:
        daily_store = get_era5_stores("daily")
        daily_raw = xr.open_zarr(daily_store, consolidated=False)
    print("=== Daily INFO ===")
    print(f" Dims: {daily_raw.sizes}")
    numchunks = 1
    chunksizes = {}
    # Size estimate assumes 4-byte (float32) values — TODO confirm against the store's dtypes.
    approxchunksize = 4
    for d, cs in daily_raw.chunksizes.items():
        # ``cs`` holds the chunk lengths along dimension ``d``: its length is the
        # number of chunks, its maximum is the (largest) chunk extent.
        numchunks *= len(cs)
        chunksizes[d] = max(cs)
        approxchunksize *= max(cs)
    approxchunksize /= 1e6  # bytes -> MB (1 MB = 1e6 bytes)
    print(f" Chunks: {chunksizes} (~{approxchunksize:.2f}MB) => {numchunks} total")
    print(f" Encoding: {daily_raw.encoding}")
    if show_vars:
        print(" Variables:")
        for var in daily_raw.data_vars:
            da = daily_raw[var]
            print(f" {var} Encoding:")
            print(da.encoding)
            print("")
def rechunk(use_shards: bool = False):
    """Rewrite the daily ERA5 zarr store with full-field spatial chunks.

    The result is written next to the source store as ``<stem>_rechunked``
    or, when ``use_shards`` is set, ``<stem>_rechunked_sharded``. Every data
    variable is encoded with chunks of 120 time steps by the full
    latitude/longitude extent; with ``use_shards`` each shard spans 1200 time
    steps (10 chunks) by the full spatial extent.

    Args:
        use_shards: Also write zarr shards. Known to be broken upstream
            (see the warning printed below).
    """
    if use_shards:
        # ! MEEEP: https://github.com/pydata/xarray/issues/10831
        print("WARNING! Rechunking with shards. This is known to be broken in xarray/dask!")
    time_chunk = 120
    time_shard = 1200
    with (
        dd.LocalCluster(n_workers=8, threads_per_worker=5, memory_limit="20GB") as cluster,
        dd.Client(cluster) as client,
    ):
        print(f"Dashboard: {client.dashboard_link}")
        daily_store = get_era5_stores("daily")
        # The unsharded output must be named "_rechunked": that is the store validate() opens.
        suffix = "_rechunked_sharded" if use_shards else "_rechunked"
        daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}{suffix}")
        daily_raw = xr.open_zarr(daily_store, consolidated=False)
        daily_raw = daily_raw.chunk(
            {
                "time": time_chunk,
                "latitude": -1,  # full extent
                "longitude": -1,  # full extent
            }
        )
        encoding = entropice.utils.codecs.from_ds(daily_raw, filter_existing=False)
        # Take the spatial extent from the data rather than hard-coding 337 x 3600,
        # so the encoded zarr chunks always match the dask chunks requested above.
        n_lat = daily_raw.sizes["latitude"]
        n_lon = daily_raw.sizes["longitude"]
        for var in daily_raw.data_vars:
            # Assumes every variable has dims (time, latitude, longitude) in this order — TODO confirm.
            encoding[var]["chunks"] = (time_chunk, n_lat, n_lon)
            if use_shards:
                encoding[var]["shards"] = (time_shard, n_lat, n_lon)
        print(encoding)
        daily_raw.to_zarr(daily_store_rechunked, mode="w", consolidated=False, encoding=encoding)
def _compare_variable(var, raw_var, rechunked_var) -> bool:
    """Compare one variable of the original and rechunked datasets, printing any differences.

    Args:
        var: Variable name, used only in the printed report.
        raw_var: The variable from the original dataset.
        rechunked_var: The same-named variable from the rechunked dataset.

    Returns:
        The result of ``raw_var.equals(rechunked_var)``.
    """
    import numpy as np

    if raw_var.equals(rechunked_var):
        print(f"{var}: Equal")
        return True

    print(f"{var}: NOT Equal")
    # Check if values are equal. ``.values`` loads the whole (possibly dask-backed)
    # array, so compare the lazy shapes first and load each array at most once.
    try:
        values_equal = raw_var.shape == rechunked_var.shape and np.allclose(
            raw_var.values, rechunked_var.values, equal_nan=True
        )
        if values_equal:
            print(" → Values are numerically equal (likely metadata/encoding difference)")
        else:
            print(" → Values differ!")
            print(f" Original shape: {raw_var.shape}")
            print(f" Rechunked shape: {rechunked_var.shape}")
    except Exception as e:
        # Best-effort diagnostics (e.g. dtypes np.allclose rejects, or load errors):
        # report the problem and keep checking the remaining variables.
        print(f" → Error comparing values: {e}")

    # Check attributes
    if raw_var.attrs != rechunked_var.attrs:
        print(" → Attributes differ:")
        print(f" Original: {raw_var.attrs}")
        print(f" Rechunked: {rechunked_var.attrs}")

    # Check encoding
    if raw_var.encoding != rechunked_var.encoding:
        print(" → Encoding differs:")
        print(f" Original: {raw_var.encoding}")
        print(f" Rechunked: {rechunked_var.encoding}")
    return False


def validate(daily_raw=None, daily_rechunked=None) -> bool:
    """Compare the daily ERA5 dataset with its rechunked copy and print a report.

    Args:
        daily_raw: Original dataset. When ``None``, the store returned by
            ``get_era5_stores("daily")`` is opened.
        daily_rechunked: Rechunked dataset. When ``None``, the
            ``<stem>_rechunked`` store next to the daily store is opened.

    Returns:
        True when the dimensions match, both datasets hold the same data
        variables, and every shared variable compares equal.
    """
    if daily_raw is None or daily_rechunked is None:
        daily_store = get_era5_stores("daily")
        if daily_raw is None:
            daily_raw = xr.open_zarr(daily_store, consolidated=False)
        if daily_rechunked is None:
            daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked")
            daily_rechunked = xr.open_zarr(daily_store_rechunked, consolidated=False)

    print("\n=== Comparing Datasets ===")
    all_equal = True

    # Compare dimensions
    if daily_raw.sizes != daily_rechunked.sizes:
        all_equal = False
        print("❌ Dimensions differ:")
        print(f" Original: {daily_raw.sizes}")
        print(f" Rechunked: {daily_rechunked.sizes}")
    else:
        print("✅ Dimensions match")

    # Compare variables
    raw_vars = set(daily_raw.data_vars)
    rechunked_vars = set(daily_rechunked.data_vars)
    if raw_vars != rechunked_vars:
        all_equal = False
        print("❌ Variables differ:")
        print(f" Only in original: {raw_vars - rechunked_vars}")
        print(f" Only in rechunked: {rechunked_vars - raw_vars}")
    else:
        print("✅ Variables match")

    # Compare each variable present in both datasets, in a stable order.
    print("\n=== Variable Comparison ===")
    for var in sorted(raw_vars & rechunked_vars, key=str):
        if not _compare_variable(var, daily_raw[var], daily_rechunked[var]):
            all_equal = False

    if all_equal:
        print("\n✅ Validation successful: All datasets are equal.")
    else:
        print("\n❌ Validation failed: Datasets have differences (see above).")
    return all_equal
def _main() -> None:
    """Command-line entry point: compare the rechunked daily store with the original."""
    validate()


if __name__ == "__main__":
    _main()