Small fixes all over the place

This commit is contained in:
Tobias Hölzer 2026-01-08 20:00:09 +01:00
parent c92e856c55
commit 1495f71ac9
9 changed files with 3923 additions and 4084 deletions

View file

@ -1,12 +1,12 @@
import xarray as xr
import zarr
from rich import print
import dask.distributed as dd
import xarray as xr
from rich import print
from entropice.utils.paths import get_era5_stores
import entropice.utils.codecs
from entropice.utils.paths import get_era5_stores
def print_info(daily_raw = None, show_vars: bool = True):
def print_info(daily_raw=None, show_vars: bool = True):
if daily_raw is None:
daily_store = get_era5_stores("daily")
daily_raw = xr.open_zarr(daily_store, consolidated=False)
@ -14,12 +14,12 @@ def print_info(daily_raw = None, show_vars: bool = True):
print(f" Dims: {daily_raw.sizes}")
numchunks = 1
chunksizes = {}
approxchunksize = 4 # 4 Bytes = float32
approxchunksize = 4 # 4 Bytes = float32
for d, cs in daily_raw.chunksizes.items():
numchunks *= len(cs)
chunksizes[d] = max(cs)
approxchunksize *= max(cs)
approxchunksize /= 10e6 # MB
approxchunksize /= 10e6 # MB
print(f" Chunks: {chunksizes} (~{approxchunksize:.2f}MB) => {numchunks} total")
print(f" Encoding: {daily_raw.encoding}")
if show_vars:
@ -30,29 +30,109 @@ def print_info(daily_raw = None, show_vars: bool = True):
print(da.encoding)
print("")
def rechunk():
def rechunk(use_shards: bool = False):
if use_shards:
# ! MEEEP: https://github.com/pydata/xarray/issues/10831
print("WARNING! Rechunking with shards. This is known to be broken in xarray/dask!")
with (
dd.LocalCluster(n_workers=8, threads_per_worker=5, memory_limit="20GB") as cluster,
dd.Client(cluster) as client,
):
print(f"Dashboard: {client.dashboard_link}")
daily_store = get_era5_stores("daily")
daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked_sharded")
daily_raw = xr.open_zarr(daily_store, consolidated=False)
daily_raw = daily_raw.chunk(
{
"time": 120,
"latitude": -1, # Should be 337,
"longitude": -1, # Should be 3600
}
)
encoding = entropice.utils.codecs.from_ds(daily_raw, filter_existing=False)
for var in daily_raw.data_vars:
encoding[var]["chunks"] = (120, 337, 3600)
if use_shards:
encoding[var]["shards"] = (1200, 337, 3600)
print(encoding)
daily_raw.to_zarr(daily_store_rechunked, mode="w", consolidated=False, encoding=encoding)
def validate():
daily_store = get_era5_stores("daily")
daily_raw = xr.open_zarr(daily_store, consolidated=False)
print_info(daily_raw, False)
daily_raw = daily_raw.chunk({
"time": 120,
"latitude": -1, # Should be 337,
"longitude": -1 # Should be 3600
})
print_info(daily_raw, False)
encoding = entropice.utils.codecs.from_ds(daily_raw)
daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked")
daily_raw.to_zarr(daily_store_rechunked, mode="w", encoding=encoding, consolidated=False)
daily_rechunked = xr.open_zarr(daily_store_rechunked, consolidated=False)
print("\n=== Comparing Datasets ===")
# Compare dimensions
if daily_raw.sizes != daily_rechunked.sizes:
print("❌ Dimensions differ:")
print(f" Original: {daily_raw.sizes}")
print(f" Rechunked: {daily_rechunked.sizes}")
else:
print("✅ Dimensions match")
# Compare variables
raw_vars = set(daily_raw.data_vars)
rechunked_vars = set(daily_rechunked.data_vars)
if raw_vars != rechunked_vars:
print("❌ Variables differ:")
print(f" Only in original: {raw_vars - rechunked_vars}")
print(f" Only in rechunked: {rechunked_vars - raw_vars}")
else:
print("✅ Variables match")
# Compare each variable
print("\n=== Variable Comparison ===")
all_equal = True
for var in raw_vars & rechunked_vars:
raw_var = daily_raw[var]
rechunked_var = daily_rechunked[var]
if raw_var.equals(rechunked_var):
print(f"{var}: Equal")
else:
all_equal = False
print(f"{var}: NOT Equal")
# Check if values are equal
try:
values_equal = raw_var.values.shape == rechunked_var.values.shape
if values_equal:
import numpy as np
values_equal = np.allclose(raw_var.values, rechunked_var.values, equal_nan=True)
if values_equal:
print(" → Values are numerically equal (likely metadata/encoding difference)")
else:
print(" → Values differ!")
print(f" Original shape: {raw_var.values.shape}")
print(f" Rechunked shape: {rechunked_var.values.shape}")
except Exception as e:
print(f" → Error comparing values: {e}")
# Check attributes
if raw_var.attrs != rechunked_var.attrs:
print(" → Attributes differ:")
print(f" Original: {raw_var.attrs}")
print(f" Rechunked: {rechunked_var.attrs}")
# Check encoding
if raw_var.encoding != rechunked_var.encoding:
print(" → Encoding differs:")
print(f" Original: {raw_var.encoding}")
print(f" Rechunked: {rechunked_var.encoding}")
if all_equal:
print("\n✅ Validation successful: All datasets are equal.")
else:
print("\n❌ Validation failed: Datasets have differences (see above).")
if __name__ == "__main__":
with (
dd.LocalCluster(n_workers=1, threads_per_worker=10, memory_limit="100GB") as cluster,
dd.Client(cluster) as client,
):
print(client)
print(client.dashboard_link)
rechunk()
print("Done.")
validate()