Small fixes all over the place
This commit is contained in:
parent
c92e856c55
commit
1495f71ac9
9 changed files with 3923 additions and 4084 deletions
|
|
@ -12,7 +12,7 @@ dependencies = [
|
|||
"cartopy>=0.24.1",
|
||||
"cdsapi>=0.7.6",
|
||||
"cyclopts>=4.0.0",
|
||||
"dask>=2025.5.1",
|
||||
"dask>=2025.11.0",
|
||||
"distributed>=2025.5.1",
|
||||
"earthengine-api>=1.6.9",
|
||||
"eemont>=2025.7.1",
|
||||
|
|
@ -34,7 +34,6 @@ dependencies = [
|
|||
"odc-geo[all]>=0.4.10",
|
||||
"opt-einsum>=3.4.0",
|
||||
"pyarrow>=18.1.0",
|
||||
"rechunker>=0.5.2",
|
||||
"requests>=2.32.3",
|
||||
"rich>=14.0.0",
|
||||
"rioxarray>=0.19.0",
|
||||
|
|
@ -66,7 +65,9 @@ dependencies = [
|
|||
"pypalettes>=0.2.1,<0.3",
|
||||
"ty>=0.0.2,<0.0.3",
|
||||
"ruff>=0.14.9,<0.15",
|
||||
"pandas-stubs>=2.3.3.251201,<3", "pytest>=9.0.2,<10",
|
||||
"pandas-stubs>=2.3.3.251201,<3",
|
||||
"pytest>=9.0.2,<10",
|
||||
"autogluon-tabular[all]>=1.5.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
@ -90,15 +91,15 @@ url = "https://pypi.nvidia.com"
|
|||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
|
||||
# entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
|
||||
entropy = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" }
|
||||
xanimate = { git = "https://github.com/davbyr/xAnimate" }
|
||||
xdem = { git = "https://github.com/GlacioHack/xdem" }
|
||||
xdggs = { git = "https://github.com/relativityhd/xdggs", branch = "feature/make-plotting-useful" }
|
||||
xarray-spatial = { git = "https://github.com/relativityhd/xarray-spatial" }
|
||||
cudf-cu12 = { index = "nvidia" }
|
||||
cuml-cu12 = { index = "nvidia" }
|
||||
cuspatial-cu12 = { index = "nvidia" }
|
||||
# cudf-cu12 = { index = "nvidia" }
|
||||
# cuml-cu12 = { index = "nvidia" }
|
||||
# cuspatial-cu12 = { index = "nvidia" }
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
|
|
@ -148,5 +149,10 @@ nccl = ">=2.27.7.1,<3"
|
|||
cudnn = ">=9.13.1.26,<10"
|
||||
cusparselt = ">=0.8.1.1,<0.9"
|
||||
cuda-version = "12.9.*"
|
||||
rapids = ">=25.10.0,<26"
|
||||
# cudf = ">=25.10.0,<26"
|
||||
# cuml = ">=25.10.0,<26"
|
||||
healpix-geo = ">=0.0.6"
|
||||
scikit-learn = ">=1.4.0,<1.8.0"
|
||||
pyarrow = ">=7.0.0,<21.0.0"
|
||||
cudf = ">=25.12.0,<26"
|
||||
cuml = ">=25.12.0,<26"
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
import xarray as xr
|
||||
import zarr
|
||||
from rich import print
|
||||
import dask.distributed as dd
|
||||
import xarray as xr
|
||||
from rich import print
|
||||
|
||||
from entropice.utils.paths import get_era5_stores
|
||||
import entropice.utils.codecs
|
||||
from entropice.utils.paths import get_era5_stores
|
||||
|
||||
|
||||
def print_info(daily_raw=None, show_vars: bool = True):
|
||||
if daily_raw is None:
|
||||
|
|
@ -30,29 +30,109 @@ def print_info(daily_raw = None, show_vars: bool = True):
|
|||
print(da.encoding)
|
||||
print("")
|
||||
|
||||
def rechunk():
|
||||
|
||||
def rechunk(use_shards: bool = False):
|
||||
if use_shards:
|
||||
# ! MEEEP: https://github.com/pydata/xarray/issues/10831
|
||||
print("WARNING! Rechunking with shards. This is known to be broken in xarray/dask!")
|
||||
with (
|
||||
dd.LocalCluster(n_workers=8, threads_per_worker=5, memory_limit="20GB") as cluster,
|
||||
dd.Client(cluster) as client,
|
||||
):
|
||||
print(f"Dashboard: {client.dashboard_link}")
|
||||
daily_store = get_era5_stores("daily")
|
||||
daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked_sharded")
|
||||
daily_raw = xr.open_zarr(daily_store, consolidated=False)
|
||||
print_info(daily_raw, False)
|
||||
daily_raw = daily_raw.chunk({
|
||||
daily_raw = daily_raw.chunk(
|
||||
{
|
||||
"time": 120,
|
||||
"latitude": -1, # Should be 337,
|
||||
"longitude": -1 # Should be 3600
|
||||
})
|
||||
print_info(daily_raw, False)
|
||||
"longitude": -1, # Should be 3600
|
||||
}
|
||||
)
|
||||
encoding = entropice.utils.codecs.from_ds(daily_raw, filter_existing=False)
|
||||
for var in daily_raw.data_vars:
|
||||
encoding[var]["chunks"] = (120, 337, 3600)
|
||||
if use_shards:
|
||||
encoding[var]["shards"] = (1200, 337, 3600)
|
||||
print(encoding)
|
||||
daily_raw.to_zarr(daily_store_rechunked, mode="w", consolidated=False, encoding=encoding)
|
||||
|
||||
|
||||
def validate():
|
||||
daily_store = get_era5_stores("daily")
|
||||
daily_raw = xr.open_zarr(daily_store, consolidated=False)
|
||||
|
||||
encoding = entropice.utils.codecs.from_ds(daily_raw)
|
||||
daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked")
|
||||
daily_raw.to_zarr(daily_store_rechunked, mode="w", encoding=encoding, consolidated=False)
|
||||
daily_rechunked = xr.open_zarr(daily_store_rechunked, consolidated=False)
|
||||
|
||||
print("\n=== Comparing Datasets ===")
|
||||
|
||||
# Compare dimensions
|
||||
if daily_raw.sizes != daily_rechunked.sizes:
|
||||
print("❌ Dimensions differ:")
|
||||
print(f" Original: {daily_raw.sizes}")
|
||||
print(f" Rechunked: {daily_rechunked.sizes}")
|
||||
else:
|
||||
print("✅ Dimensions match")
|
||||
|
||||
# Compare variables
|
||||
raw_vars = set(daily_raw.data_vars)
|
||||
rechunked_vars = set(daily_rechunked.data_vars)
|
||||
if raw_vars != rechunked_vars:
|
||||
print("❌ Variables differ:")
|
||||
print(f" Only in original: {raw_vars - rechunked_vars}")
|
||||
print(f" Only in rechunked: {rechunked_vars - raw_vars}")
|
||||
else:
|
||||
print("✅ Variables match")
|
||||
|
||||
# Compare each variable
|
||||
print("\n=== Variable Comparison ===")
|
||||
all_equal = True
|
||||
for var in raw_vars & rechunked_vars:
|
||||
raw_var = daily_raw[var]
|
||||
rechunked_var = daily_rechunked[var]
|
||||
|
||||
if raw_var.equals(rechunked_var):
|
||||
print(f"✅ {var}: Equal")
|
||||
else:
|
||||
all_equal = False
|
||||
print(f"❌ {var}: NOT Equal")
|
||||
|
||||
# Check if values are equal
|
||||
try:
|
||||
values_equal = raw_var.values.shape == rechunked_var.values.shape
|
||||
if values_equal:
|
||||
import numpy as np
|
||||
|
||||
values_equal = np.allclose(raw_var.values, rechunked_var.values, equal_nan=True)
|
||||
|
||||
if values_equal:
|
||||
print(" → Values are numerically equal (likely metadata/encoding difference)")
|
||||
else:
|
||||
print(" → Values differ!")
|
||||
print(f" Original shape: {raw_var.values.shape}")
|
||||
print(f" Rechunked shape: {rechunked_var.values.shape}")
|
||||
except Exception as e:
|
||||
print(f" → Error comparing values: {e}")
|
||||
|
||||
# Check attributes
|
||||
if raw_var.attrs != rechunked_var.attrs:
|
||||
print(" → Attributes differ:")
|
||||
print(f" Original: {raw_var.attrs}")
|
||||
print(f" Rechunked: {rechunked_var.attrs}")
|
||||
|
||||
# Check encoding
|
||||
if raw_var.encoding != rechunked_var.encoding:
|
||||
print(" → Encoding differs:")
|
||||
print(f" Original: {raw_var.encoding}")
|
||||
print(f" Rechunked: {rechunked_var.encoding}")
|
||||
|
||||
if all_equal:
|
||||
print("\n✅ Validation successful: All datasets are equal.")
|
||||
else:
|
||||
print("\n❌ Validation failed: Datasets have differences (see above).")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with (
|
||||
dd.LocalCluster(n_workers=1, threads_per_worker=10, memory_limit="100GB") as cluster,
|
||||
dd.Client(cluster) as client,
|
||||
):
|
||||
print(client)
|
||||
print(client.dashboard_link)
|
||||
rechunk()
|
||||
print("Done.")
|
||||
validate()
|
||||
|
|
|
|||
|
|
@ -9,11 +9,11 @@ import numpy as np
|
|||
import pandas as pd
|
||||
import pydeck as pdk
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
|
||||
from entropice.dashboard.utils.class_ordering import get_ordered_classes
|
||||
from entropice.dashboard.utils.colors import get_cmap, get_palette
|
||||
from entropice.dashboard.utils.geometry import fix_hex_geometry
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.ml.training import TrainingSettings
|
||||
|
||||
|
||||
|
|
@ -1154,73 +1154,36 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1
|
|||
|
||||
|
||||
@st.fragment
|
||||
def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
||||
"""Render 3D pydeck map showing prediction results.
|
||||
def render_confusion_matrix_map(
|
||||
result_path: Path, settings: TrainingSettings, merged_predictions: gpd.GeoDataFrame | None = None
|
||||
):
|
||||
"""Render 3D pydeck map showing model performance on training data.
|
||||
|
||||
Uses true labels for elevation (height) and different shades of red for incorrect predictions
|
||||
based on the predicted class.
|
||||
Displays cells from the training dataset with predictions, colored by correctness.
|
||||
Uses true labels for elevation (height) and different shades of red for incorrect predictions.
|
||||
|
||||
Args:
|
||||
result_path: Path to the training result directory.
|
||||
result_path: Path to the training result directory (not used, kept for compatibility).
|
||||
settings: Settings dictionary containing grid, level, task, and target information.
|
||||
merged_predictions: GeoDataFrame with predictions, true labels, and split info.
|
||||
|
||||
"""
|
||||
st.subheader("🗺️ Prediction Results Map")
|
||||
|
||||
# Load predicted probabilities
|
||||
preds_file = result_path / "predicted_probabilities.parquet"
|
||||
if not preds_file.exists():
|
||||
st.warning("No predicted probabilities found for this training run.")
|
||||
if merged_predictions is None:
|
||||
st.warning("Prediction data not available. Cannot display map.")
|
||||
return
|
||||
|
||||
preds_gdf = gpd.read_parquet(preds_file)
|
||||
|
||||
# Get task and target information from settings
|
||||
task = settings.task
|
||||
target = settings.target
|
||||
# Get grid type and task from settings
|
||||
grid = settings.grid
|
||||
level = settings.level
|
||||
task = settings.task
|
||||
|
||||
# Create dataset ensemble to get true labels
|
||||
# We need to load the target data to get true labels
|
||||
try:
|
||||
ensemble = DatasetEnsemble(
|
||||
grid=grid,
|
||||
level=level,
|
||||
target=target,
|
||||
members=[], # We don't need feature data, just target
|
||||
)
|
||||
training_data = ensemble.create_cat_training_dataset(task=task, device="cpu")
|
||||
except Exception as e:
|
||||
st.error(f"Error loading training data: {e}")
|
||||
return
|
||||
|
||||
# Get all cells from the complete dataset (not just test split)
|
||||
# Use the full dataset which includes both train and test splits
|
||||
all_cells = training_data.dataset.copy()
|
||||
|
||||
# Merge predictions with true labels
|
||||
# Reset index to avoid ambiguity between index and column
|
||||
labeled_gdf = all_cells.reset_index().rename(columns={"index": "cell_id"})
|
||||
labeled_gdf["true_class"] = training_data.y.binned.loc[all_cells.index].to_numpy()
|
||||
|
||||
# Merge with predictions - use left join to keep all cells
|
||||
merged_df = labeled_gdf.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="left")
|
||||
merged = gpd.GeoDataFrame(merged_df, geometry="geometry", crs=labeled_gdf.crs)
|
||||
|
||||
# Mark which cells have predictions (test split) vs not (training split)
|
||||
merged["in_test_split"] = merged["predicted_class"].notna()
|
||||
|
||||
# For cells without predictions (training split), use true class as predicted class for visualization
|
||||
merged["predicted_class"] = merged["predicted_class"].fillna(merged["true_class"])
|
||||
# Use the merged predictions which already have true labels, predictions, and split info
|
||||
merged = merged_predictions.copy()
|
||||
merged["is_correct"] = merged["true_class"] == merged["predicted_class"]
|
||||
|
||||
if len(merged) == 0:
|
||||
st.warning("No matching predictions found for labeled cells.")
|
||||
st.warning("No predictions found for labeled cells.")
|
||||
return
|
||||
|
||||
# Mark correct vs incorrect predictions (only meaningful for test split)
|
||||
merged["is_correct"] = merged["true_class"] == merged["predicted_class"]
|
||||
|
||||
# Get ordered class labels for the task
|
||||
ordered_classes = get_ordered_classes(task)
|
||||
|
||||
|
|
@ -1228,51 +1191,55 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
|||
col1, col2, col3 = st.columns([2, 1, 1])
|
||||
|
||||
with col1:
|
||||
# Filter by prediction correctness and split
|
||||
categories = ["All", "Test Split Only", "Training Split Only", "Correct (Test)", "Incorrect (Test)"]
|
||||
|
||||
selected_category = st.selectbox(
|
||||
"Filter by Category",
|
||||
options=categories,
|
||||
key="confusion_map_category",
|
||||
# Split selector (similar to confusion matrix)
|
||||
split_type = st.selectbox(
|
||||
"Select Data Split",
|
||||
options=["test", "train", "all"],
|
||||
format_func=lambda x: {"test": "Test Set", "train": "Training Set (CV)", "all": "All Data"}[x],
|
||||
help="Choose which data split to display on the map",
|
||||
key="prediction_map_split_select",
|
||||
)
|
||||
|
||||
with col2:
|
||||
# Color scheme selector
|
||||
show_only_incorrect = st.checkbox(
|
||||
"Highlight Errors Only",
|
||||
value=False,
|
||||
help="Show only incorrect predictions in red, hide correct ones",
|
||||
key="prediction_map_errors_only",
|
||||
)
|
||||
|
||||
with col3:
|
||||
opacity = st.slider(
|
||||
"Opacity",
|
||||
min_value=0.1,
|
||||
max_value=1.0,
|
||||
value=0.7,
|
||||
step=0.1,
|
||||
key="confusion_map_opacity",
|
||||
key="prediction_map_opacity",
|
||||
)
|
||||
|
||||
with col3:
|
||||
line_width = st.slider(
|
||||
"Line Width",
|
||||
min_value=0.5,
|
||||
max_value=3.0,
|
||||
value=1.0,
|
||||
step=0.5,
|
||||
key="confusion_map_line_width",
|
||||
)
|
||||
|
||||
# Filter data if needed
|
||||
if selected_category == "Test Split Only":
|
||||
display_gdf = merged[merged["in_test_split"]].copy()
|
||||
elif selected_category == "Training Split Only":
|
||||
display_gdf = merged[~merged["in_test_split"]].copy()
|
||||
elif selected_category == "Correct (Test)":
|
||||
display_gdf = merged[merged["is_correct"] & merged["in_test_split"]].copy()
|
||||
elif selected_category == "Incorrect (Test)":
|
||||
display_gdf = merged[~merged["is_correct"] & merged["in_test_split"]].copy()
|
||||
else: # "All"
|
||||
# Filter data by split
|
||||
if split_type == "test":
|
||||
display_gdf = merged[merged["split"] == "test"].copy()
|
||||
split_caption = "Test Set (held-out data)"
|
||||
elif split_type == "train":
|
||||
display_gdf = merged[merged["split"] == "train"].copy()
|
||||
split_caption = "Training Set (CV data)"
|
||||
else: # "all"
|
||||
display_gdf = merged.copy()
|
||||
split_caption = "All Available Data"
|
||||
|
||||
# Optionally filter to show only incorrect predictions
|
||||
if show_only_incorrect:
|
||||
display_gdf = display_gdf[~display_gdf["is_correct"]].copy()
|
||||
|
||||
if len(display_gdf) == 0:
|
||||
st.warning(f"No cells found for category: {selected_category}")
|
||||
st.warning(f"No cells found for {split_caption}.")
|
||||
return
|
||||
|
||||
st.caption(f"📍 Showing {len(display_gdf)} cells from {split_caption}")
|
||||
|
||||
# Convert to WGS84 for pydeck
|
||||
display_gdf_wgs84 = display_gdf.to_crs("EPSG:4326")
|
||||
|
||||
|
|
@ -1303,12 +1270,12 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
|||
|
||||
display_gdf_wgs84["fill_color"] = display_gdf_wgs84.apply(get_color, axis=1)
|
||||
|
||||
# Add line color based on split: blue for test split, orange for training split
|
||||
# Add line color based on split: blue for test, orange for train
|
||||
def get_line_color(row):
|
||||
if row["in_test_split"]:
|
||||
if row["split"] == "test":
|
||||
return [52, 152, 219] # Blue for test split
|
||||
else:
|
||||
return [230, 126, 34] # Orange for training split
|
||||
return [230, 126, 34] # Orange for train split
|
||||
|
||||
display_gdf_wgs84["line_color"] = display_gdf_wgs84.apply(get_line_color, axis=1)
|
||||
|
||||
|
|
@ -1329,18 +1296,15 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
|||
geojson_data = []
|
||||
for _, row in display_gdf_wgs84.iterrows():
|
||||
# Determine split and status for tooltip
|
||||
split_name = "Test Split" if row["in_test_split"] else "Training Split"
|
||||
if row["in_test_split"]:
|
||||
split_name = "Test" if row["split"] == "test" else "Training (CV)"
|
||||
status = "✓ Correct" if row["is_correct"] else "✗ Incorrect"
|
||||
else:
|
||||
status = "(No prediction - training data)"
|
||||
|
||||
feature = {
|
||||
"type": "Feature",
|
||||
"geometry": row["geometry"].__geo_interface__,
|
||||
"properties": {
|
||||
"true_class": str(row["true_class"]),
|
||||
"predicted_class": str(row["predicted_class"]) if row["in_test_split"] else "N/A",
|
||||
"true_label": str(row["true_class"]),
|
||||
"predicted_label": str(row["predicted_class"]),
|
||||
"is_correct": bool(row["is_correct"]),
|
||||
"split": split_name,
|
||||
"status": status,
|
||||
|
|
@ -1362,7 +1326,7 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
|||
wireframe=False,
|
||||
get_fill_color="properties.fill_color",
|
||||
get_line_color="properties.line_color",
|
||||
line_width_min_pixels=line_width,
|
||||
line_width_min_pixels=2,
|
||||
get_elevation="properties.elevation",
|
||||
elevation_scale=500000,
|
||||
pickable=True,
|
||||
|
|
@ -1376,10 +1340,10 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
|||
layers=[layer],
|
||||
initial_view_state=view_state,
|
||||
tooltip={
|
||||
"html": "<b>Split:</b> {split}<br/>"
|
||||
"<b>True Label:</b> {true_class}<br/>"
|
||||
"<b>Predicted Label:</b> {predicted_class}<br/>"
|
||||
"<b>Status:</b> {status}",
|
||||
"html": "<b>Status:</b> {status}<br/>"
|
||||
"<b>True Label:</b> {true_label}<br/>"
|
||||
"<b>Predicted Label:</b> {predicted_label}<br/>"
|
||||
"<b>Split:</b> {split}",
|
||||
"style": {"backgroundColor": "steelblue", "color": "white"},
|
||||
},
|
||||
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
|
||||
|
|
@ -1388,83 +1352,55 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
|||
# Render the map
|
||||
st.pydeck_chart(deck)
|
||||
|
||||
# Show statistics
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
# Show statistics for displayed data
|
||||
col1, col2, col3 = st.columns(3)
|
||||
|
||||
with col1:
|
||||
st.metric("Total Labeled Cells", len(merged))
|
||||
st.metric("Cells Displayed", len(display_gdf))
|
||||
|
||||
with col2:
|
||||
test_count = len(merged[merged["in_test_split"]])
|
||||
st.metric("Test Split", test_count)
|
||||
correct = len(display_gdf[display_gdf["is_correct"]])
|
||||
st.metric("Correct Predictions", correct)
|
||||
|
||||
with col3:
|
||||
train_count = len(merged[~merged["in_test_split"]])
|
||||
st.metric("Training Split", train_count)
|
||||
|
||||
with col4:
|
||||
test_cells = merged[merged["in_test_split"]]
|
||||
if len(test_cells) > 0:
|
||||
correct = len(test_cells[test_cells["is_correct"]])
|
||||
accuracy = correct / len(test_cells)
|
||||
st.metric("Test Accuracy", f"{accuracy:.2%}")
|
||||
if len(display_gdf) > 0:
|
||||
accuracy = correct / len(display_gdf)
|
||||
st.metric("Accuracy", f"{accuracy:.2%}")
|
||||
else:
|
||||
st.metric("Test Accuracy", "N/A")
|
||||
st.metric("Accuracy", "N/A")
|
||||
|
||||
# Add legend
|
||||
with st.expander("Legend", expanded=True):
|
||||
# Split indicators (border colors)
|
||||
st.markdown("**Data Split (Border Color):**")
|
||||
st.markdown("**Fill Color (Prediction Correctness):**")
|
||||
|
||||
test_count = len(merged[merged["in_test_split"]])
|
||||
train_count = len(merged[~merged["in_test_split"]])
|
||||
|
||||
st.markdown(
|
||||
f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
|
||||
f'<div style="width: 20px; height: 20px; background-color: #333; '
|
||||
f'border: 2px solid rgb(52, 152, 219); margin-right: 8px; flex-shrink: 0;"></div>'
|
||||
f"<span><b>Test Split</b> ({test_count} cells, {test_count / len(merged) * 100:.1f}%)</span></div>",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
st.markdown(
|
||||
f'<div style="display: flex; align-items: center; margin-bottom: 12px;">'
|
||||
f'<div style="width: 20px; height: 20px; background-color: #333; '
|
||||
f'border: 2px solid rgb(230, 126, 34); margin-right: 8px; flex-shrink: 0;"></div>'
|
||||
f"<span><b>Training Split</b> ({train_count} cells, {train_count / len(merged) * 100:.1f}%)</span></div>",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
st.markdown("---")
|
||||
st.markdown("**Fill Color (Prediction Results):**")
|
||||
|
||||
# Correct predictions (test split only)
|
||||
test_cells = merged[merged["in_test_split"]]
|
||||
correct = len(test_cells[test_cells["is_correct"]]) if len(test_cells) > 0 else 0
|
||||
incorrect = len(test_cells[~test_cells["is_correct"]]) if len(test_cells) > 0 else 0
|
||||
# Correct predictions
|
||||
correct_count = len(display_gdf[display_gdf["is_correct"]])
|
||||
incorrect_count = len(display_gdf[~display_gdf["is_correct"]])
|
||||
|
||||
st.markdown(
|
||||
f'<div style="display: flex; align-items: center; margin-bottom: 8px;">'
|
||||
f'<div style="width: 20px; height: 20px; background-color: rgb(46, 204, 113); '
|
||||
f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
|
||||
f"<span><b>Correct Predictions (Test)</b> ({correct} cells, {correct / len(test_cells) * 100 if len(test_cells) > 0 else 0:.1f}%)</span></div>",
|
||||
f"<span><b>Correct Predictions</b> ({correct_count} cells, "
|
||||
f"{correct_count / len(display_gdf) * 100 if len(display_gdf) > 0 else 0:.1f}%)</span></div>",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
# Incorrect predictions by predicted class (shades of red)
|
||||
if incorrect_count > 0:
|
||||
st.markdown(
|
||||
f"<b>Incorrect Predictions by Predicted Class (Test)</b> ({incorrect} cells):", unsafe_allow_html=True
|
||||
f"<b>Incorrect Predictions by Predicted Class</b> ({incorrect_count} cells):", unsafe_allow_html=True
|
||||
)
|
||||
|
||||
for class_idx, class_label in enumerate(ordered_classes):
|
||||
# Get count of incorrect predictions for this predicted class (test split only)
|
||||
count = len(test_cells[(~test_cells["is_correct"]) & (test_cells["predicted_class"] == class_label)])
|
||||
# Get count of incorrect predictions for this predicted class
|
||||
count = len(display_gdf[(~display_gdf["is_correct"]) & (display_gdf["predicted_class"] == class_label)])
|
||||
if count > 0:
|
||||
# Get color for this predicted class
|
||||
color_value = red_cmap(class_idx / max(n_classes - 1, 1))
|
||||
rgb = [int(color_value[0] * 255), int(color_value[1] * 255), int(color_value[2] * 255)]
|
||||
|
||||
percentage = count / incorrect * 100 if incorrect > 0 else 0
|
||||
percentage = count / incorrect_count * 100
|
||||
|
||||
st.markdown(
|
||||
f'<div style="display: flex; align-items: center; margin-bottom: 4px; margin-left: 20px;">'
|
||||
|
|
@ -1474,16 +1410,33 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
|||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
# Note about training split
|
||||
st.markdown("---")
|
||||
st.markdown("**Border Color (Data Split):**")
|
||||
|
||||
# Count by split in displayed data
|
||||
test_in_display = len(display_gdf[display_gdf["split"] == "test"])
|
||||
train_in_display = len(display_gdf[display_gdf["split"] == "train"])
|
||||
|
||||
if test_in_display > 0:
|
||||
st.markdown(
|
||||
f'<div style="margin-top: 8px; font-style: italic; color: #888;">'
|
||||
f"Note: Training split cells ({train_count}) are shown with their true labels (green fill) "
|
||||
f"since predictions are only available for the test split.</div>",
|
||||
f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
|
||||
f'<div style="width: 20px; height: 20px; background-color: #333; '
|
||||
f'border: 2px solid rgb(52, 152, 219); margin-right: 8px; flex-shrink: 0;"></div>'
|
||||
f"<span><b>Test Split</b> ({test_in_display} cells)</span></div>",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
if train_in_display > 0:
|
||||
st.markdown(
|
||||
f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
|
||||
f'<div style="width: 20px; height: 20px; background-color: #333; '
|
||||
f'border: 2px solid rgb(230, 126, 34); margin-right: 8px; flex-shrink: 0;"></div>'
|
||||
f"<span><b>Training Split</b> ({train_in_display} cells)</span></div>",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
st.markdown("---")
|
||||
st.markdown("**Elevation (3D):**")
|
||||
st.markdown("**Elevation (3D Height):**")
|
||||
|
||||
# Show elevation mapping for each true class
|
||||
st.markdown("Height represents the <b>true label</b>:", unsafe_allow_html=True)
|
||||
|
|
@ -1514,13 +1467,23 @@ def render_confusion_matrix_heatmap(confusion_matrix: "xr.DataArray", task: str)
|
|||
true_labels = confusion_matrix.coords["true_label"].values
|
||||
pred_labels = confusion_matrix.coords["predicted_label"].values
|
||||
|
||||
# For binary classification, map 0/1 to No-RTS/RTS
|
||||
if task == "binary":
|
||||
# Check if labels are already strings (from predictions) or numeric (from stored confusion matrices)
|
||||
first_true_label = true_labels[0]
|
||||
is_string_labels = isinstance(first_true_label, str) or (
|
||||
hasattr(first_true_label, "dtype") and first_true_label.dtype.kind in ("U", "O")
|
||||
)
|
||||
|
||||
if is_string_labels:
|
||||
# Labels are already string labels, use them directly
|
||||
true_labels_str = [str(label) for label in true_labels]
|
||||
pred_labels_str = [str(label) for label in pred_labels]
|
||||
elif task == "binary":
|
||||
# Numeric binary labels - map 0/1 to No-RTS/RTS
|
||||
label_map = {0: "No-RTS", 1: "RTS"}
|
||||
true_labels_str = [label_map.get(int(label), str(label)) for label in true_labels]
|
||||
pred_labels_str = [label_map.get(int(label), str(label)) for label in pred_labels]
|
||||
else:
|
||||
# For multiclass, use numeric labels as is
|
||||
# Numeric multiclass labels - use as is
|
||||
true_labels_str = [str(label) for label in true_labels]
|
||||
pred_labels_str = [str(label) for label in pred_labels]
|
||||
|
||||
|
|
|
|||
|
|
@ -41,7 +41,6 @@ def get_members_from_settings(settings) -> list[L2SourceDataset]:
|
|||
return settings.members
|
||||
|
||||
|
||||
@st.fragment
|
||||
def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
|
||||
"""Render sidebar for training run selection.
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@
|
|||
|
||||
from typing import cast
|
||||
|
||||
import geopandas as gpd
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
from stopuhr import stopwatch
|
||||
|
||||
from entropice.dashboard.plots.hyperparameter_analysis import (
|
||||
|
|
@ -22,6 +24,127 @@ from entropice.dashboard.utils.stats import CVResultsStatistics
|
|||
from entropice.utils.types import GridConfig
|
||||
|
||||
|
||||
def load_predictions_with_labels(selected_result: TrainingResult) -> gpd.GeoDataFrame | None:
|
||||
"""Load predictions and merge with training data to get true labels and split info.
|
||||
|
||||
Args:
|
||||
selected_result: The selected TrainingResult object.
|
||||
|
||||
Returns:
|
||||
GeoDataFrame with predictions, true labels, and split information, or None if unavailable.
|
||||
|
||||
"""
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble, bin_values, taskcol
|
||||
|
||||
# Load predictions
|
||||
preds_gdf = selected_result.load_predictions()
|
||||
if preds_gdf is None:
|
||||
return None
|
||||
|
||||
# Create a minimal dataset ensemble to access target data
|
||||
settings = selected_result.settings
|
||||
dataset_ensemble = DatasetEnsemble(
|
||||
grid=settings.grid,
|
||||
level=settings.level,
|
||||
target=settings.target,
|
||||
members=[], # No feature data needed, just targets
|
||||
)
|
||||
|
||||
# Load target dataset (just labels, no features)
|
||||
with st.spinner("Loading target labels..."):
|
||||
targets = dataset_ensemble._read_target()
|
||||
|
||||
# Get coverage and task columns
|
||||
task_col = taskcol[settings.task][settings.target]
|
||||
|
||||
# Filter for valid labels (same as in _cat_and_split)
|
||||
valid_labels = targets[task_col].notna()
|
||||
filtered_targets = targets.loc[valid_labels].copy()
|
||||
|
||||
# Apply binning to get class labels (same logic as _cat_and_split)
|
||||
if settings.task == "binary":
|
||||
binned = filtered_targets[task_col].map({False: "No RTS", True: "RTS"}).astype("category")
|
||||
elif settings.task == "count":
|
||||
binned = bin_values(filtered_targets[task_col].astype(int), task=settings.task)
|
||||
elif settings.task == "density":
|
||||
binned = bin_values(filtered_targets[task_col], task=settings.task)
|
||||
else:
|
||||
raise ValueError(f"Invalid task: {settings.task}")
|
||||
|
||||
filtered_targets["true_class"] = binned.to_numpy()
|
||||
|
||||
# Recreate the train/test split deterministically (same random_state=42 as in _cat_and_split)
|
||||
_train_idx, test_idx = train_test_split(
|
||||
filtered_targets.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True
|
||||
)
|
||||
filtered_targets["split"] = "train"
|
||||
filtered_targets.loc[test_idx, "split"] = "test"
|
||||
filtered_targets["split"] = filtered_targets["split"].astype("category")
|
||||
|
||||
# Ensure cell_id is available as a column for merging
|
||||
# Check if cell_id already exists, otherwise use the index
|
||||
if "cell_id" not in filtered_targets.columns:
|
||||
filtered_targets = filtered_targets.reset_index().rename(columns={"index": "cell_id"})
|
||||
|
||||
# Merge predictions with labels (inner join to keep only cells with predictions)
|
||||
merged = filtered_targets.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="inner")
|
||||
merged_gdf = gpd.GeoDataFrame(merged, geometry="geometry", crs=targets.crs)
|
||||
|
||||
return merged_gdf
|
||||
|
||||
|
||||
def compute_confusion_matrix_from_merged_data(
|
||||
merged_data: gpd.GeoDataFrame,
|
||||
split_type: str,
|
||||
label_names: list[str],
|
||||
) -> xr.DataArray | None:
|
||||
"""Compute confusion matrix from merged predictions and labels.
|
||||
|
||||
Args:
|
||||
merged_data: GeoDataFrame with 'true_class', 'predicted_class', and 'split' columns.
|
||||
split_type: One of 'test', 'train', or 'all'.
|
||||
label_names: List of class label names in order.
|
||||
|
||||
Returns:
|
||||
xarray.DataArray with confusion matrix or None if data unavailable.
|
||||
|
||||
"""
|
||||
from sklearn.metrics import confusion_matrix
|
||||
|
||||
# Filter by split type
|
||||
if split_type == "train":
|
||||
data = merged_data[merged_data["split"] == "train"]
|
||||
elif split_type == "test":
|
||||
data = merged_data[merged_data["split"] == "test"]
|
||||
elif split_type == "all":
|
||||
data = merged_data
|
||||
else:
|
||||
raise ValueError(f"Invalid split_type: {split_type}")
|
||||
|
||||
if len(data) == 0:
|
||||
st.warning(f"No data available for {split_type} split.")
|
||||
return None
|
||||
|
||||
# Get true and predicted labels
|
||||
y_true = data["true_class"].to_numpy()
|
||||
y_pred = data["predicted_class"].to_numpy()
|
||||
|
||||
# Compute confusion matrix
|
||||
cm = confusion_matrix(y_true, y_pred, labels=label_names)
|
||||
|
||||
# Create xarray DataArray
|
||||
cm_xr = xr.DataArray(
|
||||
cm,
|
||||
dims=["true_label", "predicted_label"],
|
||||
coords={"true_label": label_names, "predicted_label": label_names},
|
||||
name="confusion_matrix",
|
||||
)
|
||||
|
||||
return cm_xr
|
||||
|
||||
|
||||
def render_analysis_settings_sidebar(training_results: list[TrainingResult]) -> tuple[TrainingResult, str, str, int]:
|
||||
"""Render sidebar for training run and analysis settings selection.
|
||||
|
||||
|
|
@ -233,21 +356,56 @@ def render_cv_statistics_section(selected_result, selected_metric):
|
|||
)
|
||||
|
||||
|
||||
@st.fragment
def render_confusion_matrix_section(selected_result: TrainingResult, merged_predictions: gpd.GeoDataFrame | None):
    """Render confusion matrix visualization and analysis.

    Args:
        selected_result: The selected TrainingResult object.
        merged_predictions: GeoDataFrame with predictions merged with true labels and split info.

    """
    st.header("🎲 Confusion Matrix")
    st.caption("Detailed breakdown of predictions")

    # Add selector for confusion matrix type
    cm_type = st.selectbox(
        "Select Data Split",
        options=["test", "train", "all"],
        format_func=lambda x: {"test": "Test Set", "train": "CV Set (Train Split)", "all": "All Available Data"}[x],
        help="Choose which data split to display the confusion matrix for",
        key="cm_split_select",
    )

    # Get label names from settings
    label_names = selected_result.settings.classes

    # Compute or load confusion matrix based on selection
    if cm_type == "test":
        # The test-set matrix was produced during training; reuse it instead of recomputing.
        if selected_result.confusion_matrix is None:
            st.warning("No confusion matrix available for the test set.")
            return
        cm = selected_result.confusion_matrix
        st.info("📊 Showing confusion matrix for the **Test Set** (held-out data, never used during training)")
    else:
        # "train" / "all" matrices are derived on the fly from the merged predictions.
        if merged_predictions is None:
            st.warning("Predictions data not available. Cannot compute confusion matrix.")
            return

        with st.spinner(f"Computing confusion matrix for {cm_type} split..."):
            cm = compute_confusion_matrix_from_merged_data(merged_predictions, cm_type, label_names)
        if cm is None:
            return

        if cm_type == "train":
            st.info(
                "📊 Showing confusion matrix for the **CV Set (Train Split)** "
                "(data used during hyperparameter search cross-validation)"
            )
        else:  # all
            st.info("📊 Showing confusion matrix for **All Available Data** (combined train and test splits)")

    render_confusion_matrix_heatmap(cm, selected_result.settings.task)
|
||||
|
||||
|
||||
def render_parameter_space_section(selected_result, selected_metric):
|
||||
|
|
@ -374,6 +532,9 @@ def render_training_analysis_page():
|
|||
return
|
||||
selected_result, selected_metric, refit_metric, top_n = selection_result
|
||||
|
||||
# Load predictions with labels once (used by confusion matrix and map)
|
||||
merged_predictions = load_predictions_with_labels(selected_result)
|
||||
|
||||
# Main content area
|
||||
results = selected_result.results
|
||||
settings = selected_result.settings
|
||||
|
|
@ -389,7 +550,7 @@ def render_training_analysis_page():
|
|||
st.divider()
|
||||
|
||||
# Confusion Matrix Section
|
||||
render_confusion_matrix_section(selected_result)
|
||||
render_confusion_matrix_section(selected_result, merged_predictions)
|
||||
|
||||
st.divider()
|
||||
|
||||
|
|
@ -400,9 +561,10 @@ def render_training_analysis_page():
|
|||
|
||||
st.divider()
|
||||
|
||||
# Confusion Matrix Map Section
|
||||
st.header("🗺️ Prediction Results Map")
|
||||
render_confusion_matrix_map(selected_result.path, settings)
|
||||
# Prediction Analysis Map Section
|
||||
st.header("🗺️ Model Performance Map")
|
||||
st.caption("Interactive 3D map showing prediction correctness across the training dataset")
|
||||
render_confusion_matrix_map(selected_result.path, settings, merged_predictions)
|
||||
|
||||
st.divider()
|
||||
|
||||
|
|
|
|||
|
|
@ -734,14 +734,14 @@ def spatial_agg(
|
|||
3: _Aggregations.common(),
|
||||
4: _Aggregations.common(),
|
||||
5: _Aggregations(mean=True),
|
||||
6: "interpolate",
|
||||
6: "interpolate", # nearest neighbor interpolation
|
||||
},
|
||||
"healpix": {
|
||||
6: _Aggregations.common(),
|
||||
7: _Aggregations.common(),
|
||||
8: _Aggregations(mean=True),
|
||||
9: _Aggregations(mean=True),
|
||||
10: "interpolate",
|
||||
10: "interpolate", # nearest neighbor interpolation
|
||||
},
|
||||
}
|
||||
aggregations = aggregations_by_gridlevel[grid][level]
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ import cupy as cp
|
|||
import cyclopts
|
||||
import pandas as pd
|
||||
import toml
|
||||
import torch
|
||||
import xarray as xr
|
||||
from array_api_compat import get_namespace
|
||||
from cuml.ensemble import RandomForestClassifier
|
||||
from cuml.neighbors import KNeighborsClassifier
|
||||
from entropy import ESPAClassifier
|
||||
|
|
@ -233,10 +233,9 @@ def random_cv(
|
|||
# Compute predictions on the test set
|
||||
y_pred = best_estimator.predict(training_data.X.test)
|
||||
labels = list(range(len(training_data.y.labels)))
|
||||
xp = get_namespace(y_test)
|
||||
y_test = xp.as_array(y_test)
|
||||
y_pred = xp.as_array(y_pred)
|
||||
labels = xp.as_array(labels)
|
||||
y_test = torch.asarray(y_test, device="cuda")
|
||||
y_pred = torch.asarray(y_pred, device="cuda")
|
||||
labels = torch.asarray(labels, device="cuda")
|
||||
|
||||
test_metrics = {metric: _metric_functions[metric](y_test, y_pred) for metric in metrics}
|
||||
|
||||
|
|
@ -244,7 +243,7 @@ def random_cv(
|
|||
cm = confusion_matrix(y_test, y_pred, labels=labels)
|
||||
label_names = [training_data.y.labels[i] for i in range(len(training_data.y.labels))]
|
||||
cm = xr.DataArray(
|
||||
xp.as_numpy(cm),
|
||||
cm.cpu().numpy(),
|
||||
dims=["true_label", "predicted_label"],
|
||||
coords={"true_label": label_names, "predicted_label": label_names},
|
||||
name="confusion_matrix",
|
||||
|
|
|
|||
|
|
@ -368,7 +368,7 @@ def _init_worker(r: xr.Dataset | None):
|
|||
def _align_partition(
|
||||
grid_partition_gdf: gpd.GeoDataFrame,
|
||||
raster: xr.Dataset | Callable[[], xr.Dataset] | None,
|
||||
aggregations: _Aggregations | None, # None -> Interpolation
|
||||
aggregations: _Aggregations | str, # str -> Interpolation method
|
||||
pxbuffer: int,
|
||||
):
|
||||
# ? This function is expected to run inside a worker process
|
||||
|
|
@ -441,7 +441,7 @@ def _align_partition(
|
|||
)
|
||||
memprof.log_memory("After reading partial raster", log=False)
|
||||
|
||||
if aggregations is None:
|
||||
if isinstance(aggregations, str):
|
||||
cell_ids = grids.convert_cell_ids(grid_partition_gdf)
|
||||
if grid_partition_gdf.crs.to_epsg() == 4326:
|
||||
centroids = grid_partition_gdf.geometry.apply(antimeridian.fix_shape).apply(antimeridian.centroid)
|
||||
|
|
@ -453,14 +453,26 @@ def _align_partition(
|
|||
cy = centroids.y
|
||||
interp_x = xr.DataArray(cx, dims=["cell_ids"], coords={"cell_ids": cell_ids})
|
||||
interp_y = xr.DataArray(cy, dims=["cell_ids"], coords={"cell_ids": cell_ids})
|
||||
interp_coords = (
|
||||
{"latitude": interp_y, "longitude": interp_x}
|
||||
if "latitude" in raster.dims and "longitude" in raster.dims
|
||||
else {"y": interp_y, "x": interp_x}
|
||||
)
|
||||
ydim = "latitude" if "latitude" in raster.dims else "y"
|
||||
xdim = "longitude" if "longitude" in raster.dims else "x"
|
||||
interp_coords = {ydim: interp_y, xdim: interp_x}
|
||||
|
||||
# ?: Cubic does not work with NaNs in xarray interp
|
||||
with stopwatch("Interpolating data to grid centroids", log=False):
|
||||
ongrid = partial_raster.interp(interp_coords, method="linear", kwargs={"fill_value": np.nan})
|
||||
# Fill the nan
|
||||
y_is_rev = partial_raster.indexes[ydim].is_monotonic_decreasing
|
||||
if y_is_rev:
|
||||
partial_raster = partial_raster.sortby(ydim)
|
||||
partial_raster = partial_raster.interpolate_na(
|
||||
dim=ydim,
|
||||
method=aggregations,
|
||||
).interpolate_na(
|
||||
dim=xdim,
|
||||
method=aggregations,
|
||||
)
|
||||
if y_is_rev:
|
||||
partial_raster = partial_raster.sortby(ydim)
|
||||
ongrid = partial_raster.interp(interp_coords, method=aggregations, kwargs={"fill_value": np.nan})
|
||||
memprof.log_memory("After interpolating data", log=False)
|
||||
else:
|
||||
others_shape = tuple(
|
||||
|
|
@ -519,7 +531,7 @@ def _align_partition(
|
|||
def _align_data(
|
||||
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
|
||||
raster: xr.Dataset | Callable[[], xr.Dataset],
|
||||
aggregations: _Aggregations | None,
|
||||
aggregations: _Aggregations | str,
|
||||
n_partitions: int | None,
|
||||
concurrent_partitions: int,
|
||||
pxbuffer: int,
|
||||
|
|
@ -620,7 +632,7 @@ def _align_data(
|
|||
def aggregate_raster_into_grid(
|
||||
raster: xr.Dataset | Callable[[], xr.Dataset],
|
||||
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
|
||||
aggregations: _Aggregations | Literal["interpolate"],
|
||||
aggregations: _Aggregations | Literal["nearest", "linear", "cubic", "interpolate"],
|
||||
grid: Grid,
|
||||
level: int,
|
||||
n_partitions: int | None = 20,
|
||||
|
|
@ -634,7 +646,11 @@ def aggregate_raster_into_grid(
|
|||
grid_gdf (gpd.GeoDataFrame | list[gpd.GeoDataFrame]): The grid to aggregate into.
|
||||
If a list of GeoDataFrames is provided, each will be processed as a separate partition.
|
||||
No further partitioning will be done and the n_partitions argument will be ignored.
|
||||
aggregations (_Aggregations | Literal["interpolate"]): The aggregations to perform.
|
||||
aggregations (_Aggregations | Literal["nearest", "linear", "cubic", "interpolate"]):
|
||||
The aggregations to perform.
|
||||
If a string is provided, interpolation will be used with the specified method.
|
||||
If "interpolate" is provided, nearest-neighbor interpolation will be used.
|
||||
Supported methods are "nearest", "linear", and "cubic".
|
||||
grid (Grid): The type of grid to use.
|
||||
level (int): The level of the grid.
|
||||
n_partitions (int | None, optional): Number of partitions to divide the grid into. Defaults to 20.
|
||||
|
|
@ -649,7 +665,7 @@ def aggregate_raster_into_grid(
|
|||
ongrid = _align_data(
|
||||
grid_gdf,
|
||||
raster,
|
||||
aggregations if aggregations != "interpolate" else None,
|
||||
aggregations if aggregations != "interpolate" else "nearest",
|
||||
n_partitions=n_partitions,
|
||||
concurrent_partitions=concurrent_partitions,
|
||||
pxbuffer=pxbuffer,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue