Small fixes all over the place
This commit is contained in:
parent
c92e856c55
commit
1495f71ac9
9 changed files with 3923 additions and 4084 deletions
|
|
@ -12,7 +12,7 @@ dependencies = [
|
||||||
"cartopy>=0.24.1",
|
"cartopy>=0.24.1",
|
||||||
"cdsapi>=0.7.6",
|
"cdsapi>=0.7.6",
|
||||||
"cyclopts>=4.0.0",
|
"cyclopts>=4.0.0",
|
||||||
"dask>=2025.5.1",
|
"dask>=2025.11.0",
|
||||||
"distributed>=2025.5.1",
|
"distributed>=2025.5.1",
|
||||||
"earthengine-api>=1.6.9",
|
"earthengine-api>=1.6.9",
|
||||||
"eemont>=2025.7.1",
|
"eemont>=2025.7.1",
|
||||||
|
|
@ -34,7 +34,6 @@ dependencies = [
|
||||||
"odc-geo[all]>=0.4.10",
|
"odc-geo[all]>=0.4.10",
|
||||||
"opt-einsum>=3.4.0",
|
"opt-einsum>=3.4.0",
|
||||||
"pyarrow>=18.1.0",
|
"pyarrow>=18.1.0",
|
||||||
"rechunker>=0.5.2",
|
|
||||||
"requests>=2.32.3",
|
"requests>=2.32.3",
|
||||||
"rich>=14.0.0",
|
"rich>=14.0.0",
|
||||||
"rioxarray>=0.19.0",
|
"rioxarray>=0.19.0",
|
||||||
|
|
@ -66,7 +65,9 @@ dependencies = [
|
||||||
"pypalettes>=0.2.1,<0.3",
|
"pypalettes>=0.2.1,<0.3",
|
||||||
"ty>=0.0.2,<0.0.3",
|
"ty>=0.0.2,<0.0.3",
|
||||||
"ruff>=0.14.9,<0.15",
|
"ruff>=0.14.9,<0.15",
|
||||||
"pandas-stubs>=2.3.3.251201,<3", "pytest>=9.0.2,<10",
|
"pandas-stubs>=2.3.3.251201,<3",
|
||||||
|
"pytest>=9.0.2,<10",
|
||||||
|
"autogluon-tabular[all]>=1.5.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
@ -90,15 +91,15 @@ url = "https://pypi.nvidia.com"
|
||||||
explicit = true
|
explicit = true
|
||||||
|
|
||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
|
# entropyc = { git = "ssh://git@github.com/AlbertEMC2Stein/entropyc", branch = "refactor/tobi" }
|
||||||
entropy = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" }
|
entropy = { git = "ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git" }
|
||||||
xanimate = { git = "https://github.com/davbyr/xAnimate" }
|
xanimate = { git = "https://github.com/davbyr/xAnimate" }
|
||||||
xdem = { git = "https://github.com/GlacioHack/xdem" }
|
xdem = { git = "https://github.com/GlacioHack/xdem" }
|
||||||
xdggs = { git = "https://github.com/relativityhd/xdggs", branch = "feature/make-plotting-useful" }
|
xdggs = { git = "https://github.com/relativityhd/xdggs", branch = "feature/make-plotting-useful" }
|
||||||
xarray-spatial = { git = "https://github.com/relativityhd/xarray-spatial" }
|
xarray-spatial = { git = "https://github.com/relativityhd/xarray-spatial" }
|
||||||
cudf-cu12 = { index = "nvidia" }
|
# cudf-cu12 = { index = "nvidia" }
|
||||||
cuml-cu12 = { index = "nvidia" }
|
# cuml-cu12 = { index = "nvidia" }
|
||||||
cuspatial-cu12 = { index = "nvidia" }
|
# cuspatial-cu12 = { index = "nvidia" }
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 120
|
line-length = 120
|
||||||
|
|
@ -148,5 +149,10 @@ nccl = ">=2.27.7.1,<3"
|
||||||
cudnn = ">=9.13.1.26,<10"
|
cudnn = ">=9.13.1.26,<10"
|
||||||
cusparselt = ">=0.8.1.1,<0.9"
|
cusparselt = ">=0.8.1.1,<0.9"
|
||||||
cuda-version = "12.9.*"
|
cuda-version = "12.9.*"
|
||||||
rapids = ">=25.10.0,<26"
|
# cudf = ">=25.10.0,<26"
|
||||||
|
# cuml = ">=25.10.0,<26"
|
||||||
healpix-geo = ">=0.0.6"
|
healpix-geo = ">=0.0.6"
|
||||||
|
scikit-learn = ">=1.4.0,<1.8.0"
|
||||||
|
pyarrow = ">=7.0.0,<21.0.0"
|
||||||
|
cudf = ">=25.12.0,<26"
|
||||||
|
cuml = ">=25.12.0,<26"
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
import xarray as xr
|
|
||||||
import zarr
|
|
||||||
from rich import print
|
|
||||||
import dask.distributed as dd
|
import dask.distributed as dd
|
||||||
|
import xarray as xr
|
||||||
|
from rich import print
|
||||||
|
|
||||||
from entropice.utils.paths import get_era5_stores
|
|
||||||
import entropice.utils.codecs
|
import entropice.utils.codecs
|
||||||
|
from entropice.utils.paths import get_era5_stores
|
||||||
|
|
||||||
def print_info(daily_raw = None, show_vars: bool = True):
|
|
||||||
|
def print_info(daily_raw=None, show_vars: bool = True):
|
||||||
if daily_raw is None:
|
if daily_raw is None:
|
||||||
daily_store = get_era5_stores("daily")
|
daily_store = get_era5_stores("daily")
|
||||||
daily_raw = xr.open_zarr(daily_store, consolidated=False)
|
daily_raw = xr.open_zarr(daily_store, consolidated=False)
|
||||||
|
|
@ -14,12 +14,12 @@ def print_info(daily_raw = None, show_vars: bool = True):
|
||||||
print(f" Dims: {daily_raw.sizes}")
|
print(f" Dims: {daily_raw.sizes}")
|
||||||
numchunks = 1
|
numchunks = 1
|
||||||
chunksizes = {}
|
chunksizes = {}
|
||||||
approxchunksize = 4 # 4 Bytes = float32
|
approxchunksize = 4 # 4 Bytes = float32
|
||||||
for d, cs in daily_raw.chunksizes.items():
|
for d, cs in daily_raw.chunksizes.items():
|
||||||
numchunks *= len(cs)
|
numchunks *= len(cs)
|
||||||
chunksizes[d] = max(cs)
|
chunksizes[d] = max(cs)
|
||||||
approxchunksize *= max(cs)
|
approxchunksize *= max(cs)
|
||||||
approxchunksize /= 10e6 # MB
|
approxchunksize /= 10e6 # MB
|
||||||
print(f" Chunks: {chunksizes} (~{approxchunksize:.2f}MB) => {numchunks} total")
|
print(f" Chunks: {chunksizes} (~{approxchunksize:.2f}MB) => {numchunks} total")
|
||||||
print(f" Encoding: {daily_raw.encoding}")
|
print(f" Encoding: {daily_raw.encoding}")
|
||||||
if show_vars:
|
if show_vars:
|
||||||
|
|
@ -30,29 +30,109 @@ def print_info(daily_raw = None, show_vars: bool = True):
|
||||||
print(da.encoding)
|
print(da.encoding)
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
def rechunk():
|
|
||||||
|
def rechunk(use_shards: bool = False):
|
||||||
|
if use_shards:
|
||||||
|
# ! MEEEP: https://github.com/pydata/xarray/issues/10831
|
||||||
|
print("WARNING! Rechunking with shards. This is known to be broken in xarray/dask!")
|
||||||
|
with (
|
||||||
|
dd.LocalCluster(n_workers=8, threads_per_worker=5, memory_limit="20GB") as cluster,
|
||||||
|
dd.Client(cluster) as client,
|
||||||
|
):
|
||||||
|
print(f"Dashboard: {client.dashboard_link}")
|
||||||
|
daily_store = get_era5_stores("daily")
|
||||||
|
daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked_sharded")
|
||||||
|
daily_raw = xr.open_zarr(daily_store, consolidated=False)
|
||||||
|
daily_raw = daily_raw.chunk(
|
||||||
|
{
|
||||||
|
"time": 120,
|
||||||
|
"latitude": -1, # Should be 337,
|
||||||
|
"longitude": -1, # Should be 3600
|
||||||
|
}
|
||||||
|
)
|
||||||
|
encoding = entropice.utils.codecs.from_ds(daily_raw, filter_existing=False)
|
||||||
|
for var in daily_raw.data_vars:
|
||||||
|
encoding[var]["chunks"] = (120, 337, 3600)
|
||||||
|
if use_shards:
|
||||||
|
encoding[var]["shards"] = (1200, 337, 3600)
|
||||||
|
print(encoding)
|
||||||
|
daily_raw.to_zarr(daily_store_rechunked, mode="w", consolidated=False, encoding=encoding)
|
||||||
|
|
||||||
|
|
||||||
|
def validate():
|
||||||
daily_store = get_era5_stores("daily")
|
daily_store = get_era5_stores("daily")
|
||||||
daily_raw = xr.open_zarr(daily_store, consolidated=False)
|
daily_raw = xr.open_zarr(daily_store, consolidated=False)
|
||||||
print_info(daily_raw, False)
|
|
||||||
daily_raw = daily_raw.chunk({
|
|
||||||
"time": 120,
|
|
||||||
"latitude": -1, # Should be 337,
|
|
||||||
"longitude": -1 # Should be 3600
|
|
||||||
})
|
|
||||||
print_info(daily_raw, False)
|
|
||||||
|
|
||||||
encoding = entropice.utils.codecs.from_ds(daily_raw)
|
|
||||||
daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked")
|
daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked")
|
||||||
daily_raw.to_zarr(daily_store_rechunked, mode="w", encoding=encoding, consolidated=False)
|
daily_rechunked = xr.open_zarr(daily_store_rechunked, consolidated=False)
|
||||||
|
|
||||||
|
print("\n=== Comparing Datasets ===")
|
||||||
|
|
||||||
|
# Compare dimensions
|
||||||
|
if daily_raw.sizes != daily_rechunked.sizes:
|
||||||
|
print("❌ Dimensions differ:")
|
||||||
|
print(f" Original: {daily_raw.sizes}")
|
||||||
|
print(f" Rechunked: {daily_rechunked.sizes}")
|
||||||
|
else:
|
||||||
|
print("✅ Dimensions match")
|
||||||
|
|
||||||
|
# Compare variables
|
||||||
|
raw_vars = set(daily_raw.data_vars)
|
||||||
|
rechunked_vars = set(daily_rechunked.data_vars)
|
||||||
|
if raw_vars != rechunked_vars:
|
||||||
|
print("❌ Variables differ:")
|
||||||
|
print(f" Only in original: {raw_vars - rechunked_vars}")
|
||||||
|
print(f" Only in rechunked: {rechunked_vars - raw_vars}")
|
||||||
|
else:
|
||||||
|
print("✅ Variables match")
|
||||||
|
|
||||||
|
# Compare each variable
|
||||||
|
print("\n=== Variable Comparison ===")
|
||||||
|
all_equal = True
|
||||||
|
for var in raw_vars & rechunked_vars:
|
||||||
|
raw_var = daily_raw[var]
|
||||||
|
rechunked_var = daily_rechunked[var]
|
||||||
|
|
||||||
|
if raw_var.equals(rechunked_var):
|
||||||
|
print(f"✅ {var}: Equal")
|
||||||
|
else:
|
||||||
|
all_equal = False
|
||||||
|
print(f"❌ {var}: NOT Equal")
|
||||||
|
|
||||||
|
# Check if values are equal
|
||||||
|
try:
|
||||||
|
values_equal = raw_var.values.shape == rechunked_var.values.shape
|
||||||
|
if values_equal:
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
values_equal = np.allclose(raw_var.values, rechunked_var.values, equal_nan=True)
|
||||||
|
|
||||||
|
if values_equal:
|
||||||
|
print(" → Values are numerically equal (likely metadata/encoding difference)")
|
||||||
|
else:
|
||||||
|
print(" → Values differ!")
|
||||||
|
print(f" Original shape: {raw_var.values.shape}")
|
||||||
|
print(f" Rechunked shape: {rechunked_var.values.shape}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" → Error comparing values: {e}")
|
||||||
|
|
||||||
|
# Check attributes
|
||||||
|
if raw_var.attrs != rechunked_var.attrs:
|
||||||
|
print(" → Attributes differ:")
|
||||||
|
print(f" Original: {raw_var.attrs}")
|
||||||
|
print(f" Rechunked: {rechunked_var.attrs}")
|
||||||
|
|
||||||
|
# Check encoding
|
||||||
|
if raw_var.encoding != rechunked_var.encoding:
|
||||||
|
print(" → Encoding differs:")
|
||||||
|
print(f" Original: {raw_var.encoding}")
|
||||||
|
print(f" Rechunked: {rechunked_var.encoding}")
|
||||||
|
|
||||||
|
if all_equal:
|
||||||
|
print("\n✅ Validation successful: All datasets are equal.")
|
||||||
|
else:
|
||||||
|
print("\n❌ Validation failed: Datasets have differences (see above).")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
with (
|
validate()
|
||||||
dd.LocalCluster(n_workers=1, threads_per_worker=10, memory_limit="100GB") as cluster,
|
|
||||||
dd.Client(cluster) as client,
|
|
||||||
):
|
|
||||||
print(client)
|
|
||||||
print(client.dashboard_link)
|
|
||||||
rechunk()
|
|
||||||
print("Done.")
|
|
||||||
|
|
|
||||||
|
|
@ -9,11 +9,11 @@ import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pydeck as pdk
|
import pydeck as pdk
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
import xarray as xr
|
||||||
|
|
||||||
from entropice.dashboard.utils.class_ordering import get_ordered_classes
|
from entropice.dashboard.utils.class_ordering import get_ordered_classes
|
||||||
from entropice.dashboard.utils.colors import get_cmap, get_palette
|
from entropice.dashboard.utils.colors import get_cmap, get_palette
|
||||||
from entropice.dashboard.utils.geometry import fix_hex_geometry
|
from entropice.dashboard.utils.geometry import fix_hex_geometry
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
|
||||||
from entropice.ml.training import TrainingSettings
|
from entropice.ml.training import TrainingSettings
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1154,73 +1154,36 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1
|
||||||
|
|
||||||
|
|
||||||
@st.fragment
|
@st.fragment
|
||||||
def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
def render_confusion_matrix_map(
|
||||||
"""Render 3D pydeck map showing prediction results.
|
result_path: Path, settings: TrainingSettings, merged_predictions: gpd.GeoDataFrame | None = None
|
||||||
|
):
|
||||||
|
"""Render 3D pydeck map showing model performance on training data.
|
||||||
|
|
||||||
Uses true labels for elevation (height) and different shades of red for incorrect predictions
|
Displays cells from the training dataset with predictions, colored by correctness.
|
||||||
based on the predicted class.
|
Uses true labels for elevation (height) and different shades of red for incorrect predictions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
result_path: Path to the training result directory.
|
result_path: Path to the training result directory (not used, kept for compatibility).
|
||||||
settings: Settings dictionary containing grid, level, task, and target information.
|
settings: Settings dictionary containing grid, level, task, and target information.
|
||||||
|
merged_predictions: GeoDataFrame with predictions, true labels, and split info.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
st.subheader("🗺️ Prediction Results Map")
|
if merged_predictions is None:
|
||||||
|
st.warning("Prediction data not available. Cannot display map.")
|
||||||
# Load predicted probabilities
|
|
||||||
preds_file = result_path / "predicted_probabilities.parquet"
|
|
||||||
if not preds_file.exists():
|
|
||||||
st.warning("No predicted probabilities found for this training run.")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
preds_gdf = gpd.read_parquet(preds_file)
|
# Get grid type and task from settings
|
||||||
|
|
||||||
# Get task and target information from settings
|
|
||||||
task = settings.task
|
|
||||||
target = settings.target
|
|
||||||
grid = settings.grid
|
grid = settings.grid
|
||||||
level = settings.level
|
task = settings.task
|
||||||
|
|
||||||
# Create dataset ensemble to get true labels
|
# Use the merged predictions which already have true labels, predictions, and split info
|
||||||
# We need to load the target data to get true labels
|
merged = merged_predictions.copy()
|
||||||
try:
|
merged["is_correct"] = merged["true_class"] == merged["predicted_class"]
|
||||||
ensemble = DatasetEnsemble(
|
|
||||||
grid=grid,
|
|
||||||
level=level,
|
|
||||||
target=target,
|
|
||||||
members=[], # We don't need feature data, just target
|
|
||||||
)
|
|
||||||
training_data = ensemble.create_cat_training_dataset(task=task, device="cpu")
|
|
||||||
except Exception as e:
|
|
||||||
st.error(f"Error loading training data: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get all cells from the complete dataset (not just test split)
|
|
||||||
# Use the full dataset which includes both train and test splits
|
|
||||||
all_cells = training_data.dataset.copy()
|
|
||||||
|
|
||||||
# Merge predictions with true labels
|
|
||||||
# Reset index to avoid ambiguity between index and column
|
|
||||||
labeled_gdf = all_cells.reset_index().rename(columns={"index": "cell_id"})
|
|
||||||
labeled_gdf["true_class"] = training_data.y.binned.loc[all_cells.index].to_numpy()
|
|
||||||
|
|
||||||
# Merge with predictions - use left join to keep all cells
|
|
||||||
merged_df = labeled_gdf.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="left")
|
|
||||||
merged = gpd.GeoDataFrame(merged_df, geometry="geometry", crs=labeled_gdf.crs)
|
|
||||||
|
|
||||||
# Mark which cells have predictions (test split) vs not (training split)
|
|
||||||
merged["in_test_split"] = merged["predicted_class"].notna()
|
|
||||||
|
|
||||||
# For cells without predictions (training split), use true class as predicted class for visualization
|
|
||||||
merged["predicted_class"] = merged["predicted_class"].fillna(merged["true_class"])
|
|
||||||
|
|
||||||
if len(merged) == 0:
|
if len(merged) == 0:
|
||||||
st.warning("No matching predictions found for labeled cells.")
|
st.warning("No predictions found for labeled cells.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Mark correct vs incorrect predictions (only meaningful for test split)
|
|
||||||
merged["is_correct"] = merged["true_class"] == merged["predicted_class"]
|
|
||||||
|
|
||||||
# Get ordered class labels for the task
|
# Get ordered class labels for the task
|
||||||
ordered_classes = get_ordered_classes(task)
|
ordered_classes = get_ordered_classes(task)
|
||||||
|
|
||||||
|
|
@ -1228,51 +1191,55 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
||||||
col1, col2, col3 = st.columns([2, 1, 1])
|
col1, col2, col3 = st.columns([2, 1, 1])
|
||||||
|
|
||||||
with col1:
|
with col1:
|
||||||
# Filter by prediction correctness and split
|
# Split selector (similar to confusion matrix)
|
||||||
categories = ["All", "Test Split Only", "Training Split Only", "Correct (Test)", "Incorrect (Test)"]
|
split_type = st.selectbox(
|
||||||
|
"Select Data Split",
|
||||||
selected_category = st.selectbox(
|
options=["test", "train", "all"],
|
||||||
"Filter by Category",
|
format_func=lambda x: {"test": "Test Set", "train": "Training Set (CV)", "all": "All Data"}[x],
|
||||||
options=categories,
|
help="Choose which data split to display on the map",
|
||||||
key="confusion_map_category",
|
key="prediction_map_split_select",
|
||||||
)
|
)
|
||||||
|
|
||||||
with col2:
|
with col2:
|
||||||
|
# Color scheme selector
|
||||||
|
show_only_incorrect = st.checkbox(
|
||||||
|
"Highlight Errors Only",
|
||||||
|
value=False,
|
||||||
|
help="Show only incorrect predictions in red, hide correct ones",
|
||||||
|
key="prediction_map_errors_only",
|
||||||
|
)
|
||||||
|
|
||||||
|
with col3:
|
||||||
opacity = st.slider(
|
opacity = st.slider(
|
||||||
"Opacity",
|
"Opacity",
|
||||||
min_value=0.1,
|
min_value=0.1,
|
||||||
max_value=1.0,
|
max_value=1.0,
|
||||||
value=0.7,
|
value=0.7,
|
||||||
step=0.1,
|
step=0.1,
|
||||||
key="confusion_map_opacity",
|
key="prediction_map_opacity",
|
||||||
)
|
)
|
||||||
|
|
||||||
with col3:
|
# Filter data by split
|
||||||
line_width = st.slider(
|
if split_type == "test":
|
||||||
"Line Width",
|
display_gdf = merged[merged["split"] == "test"].copy()
|
||||||
min_value=0.5,
|
split_caption = "Test Set (held-out data)"
|
||||||
max_value=3.0,
|
elif split_type == "train":
|
||||||
value=1.0,
|
display_gdf = merged[merged["split"] == "train"].copy()
|
||||||
step=0.5,
|
split_caption = "Training Set (CV data)"
|
||||||
key="confusion_map_line_width",
|
else: # "all"
|
||||||
)
|
|
||||||
|
|
||||||
# Filter data if needed
|
|
||||||
if selected_category == "Test Split Only":
|
|
||||||
display_gdf = merged[merged["in_test_split"]].copy()
|
|
||||||
elif selected_category == "Training Split Only":
|
|
||||||
display_gdf = merged[~merged["in_test_split"]].copy()
|
|
||||||
elif selected_category == "Correct (Test)":
|
|
||||||
display_gdf = merged[merged["is_correct"] & merged["in_test_split"]].copy()
|
|
||||||
elif selected_category == "Incorrect (Test)":
|
|
||||||
display_gdf = merged[~merged["is_correct"] & merged["in_test_split"]].copy()
|
|
||||||
else: # "All"
|
|
||||||
display_gdf = merged.copy()
|
display_gdf = merged.copy()
|
||||||
|
split_caption = "All Available Data"
|
||||||
|
|
||||||
|
# Optionally filter to show only incorrect predictions
|
||||||
|
if show_only_incorrect:
|
||||||
|
display_gdf = display_gdf[~display_gdf["is_correct"]].copy()
|
||||||
|
|
||||||
if len(display_gdf) == 0:
|
if len(display_gdf) == 0:
|
||||||
st.warning(f"No cells found for category: {selected_category}")
|
st.warning(f"No cells found for {split_caption}.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
st.caption(f"📍 Showing {len(display_gdf)} cells from {split_caption}")
|
||||||
|
|
||||||
# Convert to WGS84 for pydeck
|
# Convert to WGS84 for pydeck
|
||||||
display_gdf_wgs84 = display_gdf.to_crs("EPSG:4326")
|
display_gdf_wgs84 = display_gdf.to_crs("EPSG:4326")
|
||||||
|
|
||||||
|
|
@ -1303,12 +1270,12 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
||||||
|
|
||||||
display_gdf_wgs84["fill_color"] = display_gdf_wgs84.apply(get_color, axis=1)
|
display_gdf_wgs84["fill_color"] = display_gdf_wgs84.apply(get_color, axis=1)
|
||||||
|
|
||||||
# Add line color based on split: blue for test split, orange for training split
|
# Add line color based on split: blue for test, orange for train
|
||||||
def get_line_color(row):
|
def get_line_color(row):
|
||||||
if row["in_test_split"]:
|
if row["split"] == "test":
|
||||||
return [52, 152, 219] # Blue for test split
|
return [52, 152, 219] # Blue for test split
|
||||||
else:
|
else:
|
||||||
return [230, 126, 34] # Orange for training split
|
return [230, 126, 34] # Orange for train split
|
||||||
|
|
||||||
display_gdf_wgs84["line_color"] = display_gdf_wgs84.apply(get_line_color, axis=1)
|
display_gdf_wgs84["line_color"] = display_gdf_wgs84.apply(get_line_color, axis=1)
|
||||||
|
|
||||||
|
|
@ -1329,18 +1296,15 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
||||||
geojson_data = []
|
geojson_data = []
|
||||||
for _, row in display_gdf_wgs84.iterrows():
|
for _, row in display_gdf_wgs84.iterrows():
|
||||||
# Determine split and status for tooltip
|
# Determine split and status for tooltip
|
||||||
split_name = "Test Split" if row["in_test_split"] else "Training Split"
|
split_name = "Test" if row["split"] == "test" else "Training (CV)"
|
||||||
if row["in_test_split"]:
|
status = "✓ Correct" if row["is_correct"] else "✗ Incorrect"
|
||||||
status = "✓ Correct" if row["is_correct"] else "✗ Incorrect"
|
|
||||||
else:
|
|
||||||
status = "(No prediction - training data)"
|
|
||||||
|
|
||||||
feature = {
|
feature = {
|
||||||
"type": "Feature",
|
"type": "Feature",
|
||||||
"geometry": row["geometry"].__geo_interface__,
|
"geometry": row["geometry"].__geo_interface__,
|
||||||
"properties": {
|
"properties": {
|
||||||
"true_class": str(row["true_class"]),
|
"true_label": str(row["true_class"]),
|
||||||
"predicted_class": str(row["predicted_class"]) if row["in_test_split"] else "N/A",
|
"predicted_label": str(row["predicted_class"]),
|
||||||
"is_correct": bool(row["is_correct"]),
|
"is_correct": bool(row["is_correct"]),
|
||||||
"split": split_name,
|
"split": split_name,
|
||||||
"status": status,
|
"status": status,
|
||||||
|
|
@ -1362,7 +1326,7 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
||||||
wireframe=False,
|
wireframe=False,
|
||||||
get_fill_color="properties.fill_color",
|
get_fill_color="properties.fill_color",
|
||||||
get_line_color="properties.line_color",
|
get_line_color="properties.line_color",
|
||||||
line_width_min_pixels=line_width,
|
line_width_min_pixels=2,
|
||||||
get_elevation="properties.elevation",
|
get_elevation="properties.elevation",
|
||||||
elevation_scale=500000,
|
elevation_scale=500000,
|
||||||
pickable=True,
|
pickable=True,
|
||||||
|
|
@ -1376,10 +1340,10 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
||||||
layers=[layer],
|
layers=[layer],
|
||||||
initial_view_state=view_state,
|
initial_view_state=view_state,
|
||||||
tooltip={
|
tooltip={
|
||||||
"html": "<b>Split:</b> {split}<br/>"
|
"html": "<b>Status:</b> {status}<br/>"
|
||||||
"<b>True Label:</b> {true_class}<br/>"
|
"<b>True Label:</b> {true_label}<br/>"
|
||||||
"<b>Predicted Label:</b> {predicted_class}<br/>"
|
"<b>Predicted Label:</b> {predicted_label}<br/>"
|
||||||
"<b>Status:</b> {status}",
|
"<b>Split:</b> {split}",
|
||||||
"style": {"backgroundColor": "steelblue", "color": "white"},
|
"style": {"backgroundColor": "steelblue", "color": "white"},
|
||||||
},
|
},
|
||||||
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
|
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
|
||||||
|
|
@ -1388,102 +1352,91 @@ def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
|
||||||
# Render the map
|
# Render the map
|
||||||
st.pydeck_chart(deck)
|
st.pydeck_chart(deck)
|
||||||
|
|
||||||
# Show statistics
|
# Show statistics for displayed data
|
||||||
col1, col2, col3, col4 = st.columns(4)
|
col1, col2, col3 = st.columns(3)
|
||||||
|
|
||||||
with col1:
|
with col1:
|
||||||
st.metric("Total Labeled Cells", len(merged))
|
st.metric("Cells Displayed", len(display_gdf))
|
||||||
|
|
||||||
with col2:
|
with col2:
|
||||||
test_count = len(merged[merged["in_test_split"]])
|
correct = len(display_gdf[display_gdf["is_correct"]])
|
||||||
st.metric("Test Split", test_count)
|
st.metric("Correct Predictions", correct)
|
||||||
|
|
||||||
with col3:
|
with col3:
|
||||||
train_count = len(merged[~merged["in_test_split"]])
|
if len(display_gdf) > 0:
|
||||||
st.metric("Training Split", train_count)
|
accuracy = correct / len(display_gdf)
|
||||||
|
st.metric("Accuracy", f"{accuracy:.2%}")
|
||||||
with col4:
|
|
||||||
test_cells = merged[merged["in_test_split"]]
|
|
||||||
if len(test_cells) > 0:
|
|
||||||
correct = len(test_cells[test_cells["is_correct"]])
|
|
||||||
accuracy = correct / len(test_cells)
|
|
||||||
st.metric("Test Accuracy", f"{accuracy:.2%}")
|
|
||||||
else:
|
else:
|
||||||
st.metric("Test Accuracy", "N/A")
|
st.metric("Accuracy", "N/A")
|
||||||
|
|
||||||
# Add legend
|
# Add legend
|
||||||
with st.expander("Legend", expanded=True):
|
with st.expander("Legend", expanded=True):
|
||||||
# Split indicators (border colors)
|
st.markdown("**Fill Color (Prediction Correctness):**")
|
||||||
st.markdown("**Data Split (Border Color):**")
|
|
||||||
|
|
||||||
test_count = len(merged[merged["in_test_split"]])
|
# Correct predictions
|
||||||
train_count = len(merged[~merged["in_test_split"]])
|
correct_count = len(display_gdf[display_gdf["is_correct"]])
|
||||||
|
incorrect_count = len(display_gdf[~display_gdf["is_correct"]])
|
||||||
st.markdown(
|
|
||||||
f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
|
|
||||||
f'<div style="width: 20px; height: 20px; background-color: #333; '
|
|
||||||
f'border: 2px solid rgb(52, 152, 219); margin-right: 8px; flex-shrink: 0;"></div>'
|
|
||||||
f"<span><b>Test Split</b> ({test_count} cells, {test_count / len(merged) * 100:.1f}%)</span></div>",
|
|
||||||
unsafe_allow_html=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
st.markdown(
|
|
||||||
f'<div style="display: flex; align-items: center; margin-bottom: 12px;">'
|
|
||||||
f'<div style="width: 20px; height: 20px; background-color: #333; '
|
|
||||||
f'border: 2px solid rgb(230, 126, 34); margin-right: 8px; flex-shrink: 0;"></div>'
|
|
||||||
f"<span><b>Training Split</b> ({train_count} cells, {train_count / len(merged) * 100:.1f}%)</span></div>",
|
|
||||||
unsafe_allow_html=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
st.markdown("**Fill Color (Prediction Results):**")
|
|
||||||
|
|
||||||
# Correct predictions (test split only)
|
|
||||||
test_cells = merged[merged["in_test_split"]]
|
|
||||||
correct = len(test_cells[test_cells["is_correct"]]) if len(test_cells) > 0 else 0
|
|
||||||
incorrect = len(test_cells[~test_cells["is_correct"]]) if len(test_cells) > 0 else 0
|
|
||||||
|
|
||||||
st.markdown(
|
st.markdown(
|
||||||
f'<div style="display: flex; align-items: center; margin-bottom: 8px;">'
|
f'<div style="display: flex; align-items: center; margin-bottom: 8px;">'
|
||||||
f'<div style="width: 20px; height: 20px; background-color: rgb(46, 204, 113); '
|
f'<div style="width: 20px; height: 20px; background-color: rgb(46, 204, 113); '
|
||||||
f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
|
f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
|
||||||
f"<span><b>Correct Predictions (Test)</b> ({correct} cells, {correct / len(test_cells) * 100 if len(test_cells) > 0 else 0:.1f}%)</span></div>",
|
f"<span><b>Correct Predictions</b> ({correct_count} cells, "
|
||||||
|
f"{correct_count / len(display_gdf) * 100 if len(display_gdf) > 0 else 0:.1f}%)</span></div>",
|
||||||
unsafe_allow_html=True,
|
unsafe_allow_html=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Incorrect predictions by predicted class (shades of red)
|
# Incorrect predictions by predicted class (shades of red)
|
||||||
st.markdown(
|
if incorrect_count > 0:
|
||||||
f"<b>Incorrect Predictions by Predicted Class (Test)</b> ({incorrect} cells):", unsafe_allow_html=True
|
st.markdown(
|
||||||
)
|
f"<b>Incorrect Predictions by Predicted Class</b> ({incorrect_count} cells):", unsafe_allow_html=True
|
||||||
|
)
|
||||||
|
|
||||||
for class_idx, class_label in enumerate(ordered_classes):
|
for class_idx, class_label in enumerate(ordered_classes):
|
||||||
# Get count of incorrect predictions for this predicted class (test split only)
|
# Get count of incorrect predictions for this predicted class
|
||||||
count = len(test_cells[(~test_cells["is_correct"]) & (test_cells["predicted_class"] == class_label)])
|
count = len(display_gdf[(~display_gdf["is_correct"]) & (display_gdf["predicted_class"] == class_label)])
|
||||||
if count > 0:
|
if count > 0:
|
||||||
# Get color for this predicted class
|
# Get color for this predicted class
|
||||||
color_value = red_cmap(class_idx / max(n_classes - 1, 1))
|
color_value = red_cmap(class_idx / max(n_classes - 1, 1))
|
||||||
rgb = [int(color_value[0] * 255), int(color_value[1] * 255), int(color_value[2] * 255)]
|
rgb = [int(color_value[0] * 255), int(color_value[1] * 255), int(color_value[2] * 255)]
|
||||||
|
|
||||||
percentage = count / incorrect * 100 if incorrect > 0 else 0
|
percentage = count / incorrect_count * 100
|
||||||
|
|
||||||
st.markdown(
|
st.markdown(
|
||||||
f'<div style="display: flex; align-items: center; margin-bottom: 4px; margin-left: 20px;">'
|
f'<div style="display: flex; align-items: center; margin-bottom: 4px; margin-left: 20px;">'
|
||||||
f'<div style="width: 20px; height: 20px; background-color: rgb({rgb[0]}, {rgb[1]}, {rgb[2]}); '
|
f'<div style="width: 20px; height: 20px; background-color: rgb({rgb[0]}, {rgb[1]}, {rgb[2]}); '
|
||||||
f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
|
f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
|
||||||
f"<span>Predicted as <i>{class_label}</i>: {count} ({percentage:.1f}%)</span></div>",
|
f"<span>Predicted as <i>{class_label}</i>: {count} ({percentage:.1f}%)</span></div>",
|
||||||
unsafe_allow_html=True,
|
unsafe_allow_html=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Note about training split
|
|
||||||
st.markdown(
|
|
||||||
f'<div style="margin-top: 8px; font-style: italic; color: #888;">'
|
|
||||||
f"Note: Training split cells ({train_count}) are shown with their true labels (green fill) "
|
|
||||||
f"since predictions are only available for the test split.</div>",
|
|
||||||
unsafe_allow_html=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
st.markdown("---")
|
st.markdown("---")
|
||||||
st.markdown("**Elevation (3D):**")
|
st.markdown("**Border Color (Data Split):**")
|
||||||
|
|
||||||
|
# Count by split in displayed data
|
||||||
|
test_in_display = len(display_gdf[display_gdf["split"] == "test"])
|
||||||
|
train_in_display = len(display_gdf[display_gdf["split"] == "train"])
|
||||||
|
|
||||||
|
if test_in_display > 0:
|
||||||
|
st.markdown(
|
||||||
|
f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
|
||||||
|
f'<div style="width: 20px; height: 20px; background-color: #333; '
|
||||||
|
f'border: 2px solid rgb(52, 152, 219); margin-right: 8px; flex-shrink: 0;"></div>'
|
||||||
|
f"<span><b>Test Split</b> ({test_in_display} cells)</span></div>",
|
||||||
|
unsafe_allow_html=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if train_in_display > 0:
|
||||||
|
st.markdown(
|
||||||
|
f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
|
||||||
|
f'<div style="width: 20px; height: 20px; background-color: #333; '
|
||||||
|
f'border: 2px solid rgb(230, 126, 34); margin-right: 8px; flex-shrink: 0;"></div>'
|
||||||
|
f"<span><b>Training Split</b> ({train_in_display} cells)</span></div>",
|
||||||
|
unsafe_allow_html=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
st.markdown("**Elevation (3D Height):**")
|
||||||
|
|
||||||
# Show elevation mapping for each true class
|
# Show elevation mapping for each true class
|
||||||
st.markdown("Height represents the <b>true label</b>:", unsafe_allow_html=True)
|
st.markdown("Height represents the <b>true label</b>:", unsafe_allow_html=True)
|
||||||
|
|
@ -1514,13 +1467,23 @@ def render_confusion_matrix_heatmap(confusion_matrix: "xr.DataArray", task: str)
|
||||||
true_labels = confusion_matrix.coords["true_label"].values
|
true_labels = confusion_matrix.coords["true_label"].values
|
||||||
pred_labels = confusion_matrix.coords["predicted_label"].values
|
pred_labels = confusion_matrix.coords["predicted_label"].values
|
||||||
|
|
||||||
# For binary classification, map 0/1 to No-RTS/RTS
|
# Check if labels are already strings (from predictions) or numeric (from stored confusion matrices)
|
||||||
if task == "binary":
|
first_true_label = true_labels[0]
|
||||||
|
is_string_labels = isinstance(first_true_label, str) or (
|
||||||
|
hasattr(first_true_label, "dtype") and first_true_label.dtype.kind in ("U", "O")
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_string_labels:
|
||||||
|
# Labels are already string labels, use them directly
|
||||||
|
true_labels_str = [str(label) for label in true_labels]
|
||||||
|
pred_labels_str = [str(label) for label in pred_labels]
|
||||||
|
elif task == "binary":
|
||||||
|
# Numeric binary labels - map 0/1 to No-RTS/RTS
|
||||||
label_map = {0: "No-RTS", 1: "RTS"}
|
label_map = {0: "No-RTS", 1: "RTS"}
|
||||||
true_labels_str = [label_map.get(int(label), str(label)) for label in true_labels]
|
true_labels_str = [label_map.get(int(label), str(label)) for label in true_labels]
|
||||||
pred_labels_str = [label_map.get(int(label), str(label)) for label in pred_labels]
|
pred_labels_str = [label_map.get(int(label), str(label)) for label in pred_labels]
|
||||||
else:
|
else:
|
||||||
# For multiclass, use numeric labels as is
|
# Numeric multiclass labels - use as is
|
||||||
true_labels_str = [str(label) for label in true_labels]
|
true_labels_str = [str(label) for label in true_labels]
|
||||||
pred_labels_str = [str(label) for label in pred_labels]
|
pred_labels_str = [str(label) for label in pred_labels]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,6 @@ def get_members_from_settings(settings) -> list[L2SourceDataset]:
|
||||||
return settings.members
|
return settings.members
|
||||||
|
|
||||||
|
|
||||||
@st.fragment
|
|
||||||
def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
|
def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
|
||||||
"""Render sidebar for training run selection.
|
"""Render sidebar for training run selection.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,9 @@
|
||||||
|
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
|
import geopandas as gpd
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
import xarray as xr
|
||||||
from stopuhr import stopwatch
|
from stopuhr import stopwatch
|
||||||
|
|
||||||
from entropice.dashboard.plots.hyperparameter_analysis import (
|
from entropice.dashboard.plots.hyperparameter_analysis import (
|
||||||
|
|
@ -22,6 +24,127 @@ from entropice.dashboard.utils.stats import CVResultsStatistics
|
||||||
from entropice.utils.types import GridConfig
|
from entropice.utils.types import GridConfig
|
||||||
|
|
||||||
|
|
||||||
|
def load_predictions_with_labels(selected_result: TrainingResult) -> gpd.GeoDataFrame | None:
|
||||||
|
"""Load predictions and merge with training data to get true labels and split info.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
selected_result: The selected TrainingResult object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GeoDataFrame with predictions, true labels, and split information, or None if unavailable.
|
||||||
|
|
||||||
|
"""
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
|
from entropice.ml.dataset import DatasetEnsemble, bin_values, taskcol
|
||||||
|
|
||||||
|
# Load predictions
|
||||||
|
preds_gdf = selected_result.load_predictions()
|
||||||
|
if preds_gdf is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create a minimal dataset ensemble to access target data
|
||||||
|
settings = selected_result.settings
|
||||||
|
dataset_ensemble = DatasetEnsemble(
|
||||||
|
grid=settings.grid,
|
||||||
|
level=settings.level,
|
||||||
|
target=settings.target,
|
||||||
|
members=[], # No feature data needed, just targets
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load target dataset (just labels, no features)
|
||||||
|
with st.spinner("Loading target labels..."):
|
||||||
|
targets = dataset_ensemble._read_target()
|
||||||
|
|
||||||
|
# Get coverage and task columns
|
||||||
|
task_col = taskcol[settings.task][settings.target]
|
||||||
|
|
||||||
|
# Filter for valid labels (same as in _cat_and_split)
|
||||||
|
valid_labels = targets[task_col].notna()
|
||||||
|
filtered_targets = targets.loc[valid_labels].copy()
|
||||||
|
|
||||||
|
# Apply binning to get class labels (same logic as _cat_and_split)
|
||||||
|
if settings.task == "binary":
|
||||||
|
binned = filtered_targets[task_col].map({False: "No RTS", True: "RTS"}).astype("category")
|
||||||
|
elif settings.task == "count":
|
||||||
|
binned = bin_values(filtered_targets[task_col].astype(int), task=settings.task)
|
||||||
|
elif settings.task == "density":
|
||||||
|
binned = bin_values(filtered_targets[task_col], task=settings.task)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid task: {settings.task}")
|
||||||
|
|
||||||
|
filtered_targets["true_class"] = binned.to_numpy()
|
||||||
|
|
||||||
|
# Recreate the train/test split deterministically (same random_state=42 as in _cat_and_split)
|
||||||
|
_train_idx, test_idx = train_test_split(
|
||||||
|
filtered_targets.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True
|
||||||
|
)
|
||||||
|
filtered_targets["split"] = "train"
|
||||||
|
filtered_targets.loc[test_idx, "split"] = "test"
|
||||||
|
filtered_targets["split"] = filtered_targets["split"].astype("category")
|
||||||
|
|
||||||
|
# Ensure cell_id is available as a column for merging
|
||||||
|
# Check if cell_id already exists, otherwise use the index
|
||||||
|
if "cell_id" not in filtered_targets.columns:
|
||||||
|
filtered_targets = filtered_targets.reset_index().rename(columns={"index": "cell_id"})
|
||||||
|
|
||||||
|
# Merge predictions with labels (inner join to keep only cells with predictions)
|
||||||
|
merged = filtered_targets.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="inner")
|
||||||
|
merged_gdf = gpd.GeoDataFrame(merged, geometry="geometry", crs=targets.crs)
|
||||||
|
|
||||||
|
return merged_gdf
|
||||||
|
|
||||||
|
|
||||||
|
def compute_confusion_matrix_from_merged_data(
|
||||||
|
merged_data: gpd.GeoDataFrame,
|
||||||
|
split_type: str,
|
||||||
|
label_names: list[str],
|
||||||
|
) -> xr.DataArray | None:
|
||||||
|
"""Compute confusion matrix from merged predictions and labels.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
merged_data: GeoDataFrame with 'true_class', 'predicted_class', and 'split' columns.
|
||||||
|
split_type: One of 'test', 'train', or 'all'.
|
||||||
|
label_names: List of class label names in order.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
xarray.DataArray with confusion matrix or None if data unavailable.
|
||||||
|
|
||||||
|
"""
|
||||||
|
from sklearn.metrics import confusion_matrix
|
||||||
|
|
||||||
|
# Filter by split type
|
||||||
|
if split_type == "train":
|
||||||
|
data = merged_data[merged_data["split"] == "train"]
|
||||||
|
elif split_type == "test":
|
||||||
|
data = merged_data[merged_data["split"] == "test"]
|
||||||
|
elif split_type == "all":
|
||||||
|
data = merged_data
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid split_type: {split_type}")
|
||||||
|
|
||||||
|
if len(data) == 0:
|
||||||
|
st.warning(f"No data available for {split_type} split.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get true and predicted labels
|
||||||
|
y_true = data["true_class"].to_numpy()
|
||||||
|
y_pred = data["predicted_class"].to_numpy()
|
||||||
|
|
||||||
|
# Compute confusion matrix
|
||||||
|
cm = confusion_matrix(y_true, y_pred, labels=label_names)
|
||||||
|
|
||||||
|
# Create xarray DataArray
|
||||||
|
cm_xr = xr.DataArray(
|
||||||
|
cm,
|
||||||
|
dims=["true_label", "predicted_label"],
|
||||||
|
coords={"true_label": label_names, "predicted_label": label_names},
|
||||||
|
name="confusion_matrix",
|
||||||
|
)
|
||||||
|
|
||||||
|
return cm_xr
|
||||||
|
|
||||||
|
|
||||||
def render_analysis_settings_sidebar(training_results: list[TrainingResult]) -> tuple[TrainingResult, str, str, int]:
|
def render_analysis_settings_sidebar(training_results: list[TrainingResult]) -> tuple[TrainingResult, str, str, int]:
|
||||||
"""Render sidebar for training run and analysis settings selection.
|
"""Render sidebar for training run and analysis settings selection.
|
||||||
|
|
||||||
|
|
@ -233,21 +356,56 @@ def render_cv_statistics_section(selected_result, selected_metric):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def render_confusion_matrix_section(selected_result: TrainingResult):
|
@st.fragment
|
||||||
|
def render_confusion_matrix_section(selected_result: TrainingResult, merged_predictions: gpd.GeoDataFrame | None):
|
||||||
"""Render confusion matrix visualization and analysis.
|
"""Render confusion matrix visualization and analysis.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
selected_result: The selected TrainingResult object.
|
selected_result: The selected TrainingResult object.
|
||||||
|
merged_predictions: GeoDataFrame with predictions merged with true labels and split info.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
st.header("🎲 Confusion Matrix")
|
st.header("🎲 Confusion Matrix")
|
||||||
st.caption("Detailed breakdown of predictions on the test set")
|
st.caption("Detailed breakdown of predictions")
|
||||||
|
|
||||||
if selected_result.confusion_matrix is None:
|
# Add selector for confusion matrix type
|
||||||
st.warning("No confusion matrix available for this training run.")
|
cm_type = st.selectbox(
|
||||||
return
|
"Select Data Split",
|
||||||
|
options=["test", "train", "all"],
|
||||||
|
format_func=lambda x: {"test": "Test Set", "train": "CV Set (Train Split)", "all": "All Available Data"}[x],
|
||||||
|
help="Choose which data split to display the confusion matrix for",
|
||||||
|
key="cm_split_select",
|
||||||
|
)
|
||||||
|
|
||||||
render_confusion_matrix_heatmap(selected_result.confusion_matrix, selected_result.settings.task)
|
# Get label names from settings
|
||||||
|
label_names = selected_result.settings.classes
|
||||||
|
|
||||||
|
# Compute or load confusion matrix based on selection
|
||||||
|
if cm_type == "test":
|
||||||
|
if selected_result.confusion_matrix is None:
|
||||||
|
st.warning("No confusion matrix available for the test set.")
|
||||||
|
return
|
||||||
|
cm = selected_result.confusion_matrix
|
||||||
|
st.info("📊 Showing confusion matrix for the **Test Set** (held-out data, never used during training)")
|
||||||
|
else:
|
||||||
|
if merged_predictions is None:
|
||||||
|
st.warning("Predictions data not available. Cannot compute confusion matrix.")
|
||||||
|
return
|
||||||
|
|
||||||
|
with st.spinner(f"Computing confusion matrix for {cm_type} split..."):
|
||||||
|
cm = compute_confusion_matrix_from_merged_data(merged_predictions, cm_type, label_names)
|
||||||
|
if cm is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if cm_type == "train":
|
||||||
|
st.info(
|
||||||
|
"📊 Showing confusion matrix for the **CV Set (Train Split)** "
|
||||||
|
"(data used during hyperparameter search cross-validation)"
|
||||||
|
)
|
||||||
|
else: # all
|
||||||
|
st.info("📊 Showing confusion matrix for **All Available Data** (combined train and test splits)")
|
||||||
|
|
||||||
|
render_confusion_matrix_heatmap(cm, selected_result.settings.task)
|
||||||
|
|
||||||
|
|
||||||
def render_parameter_space_section(selected_result, selected_metric):
|
def render_parameter_space_section(selected_result, selected_metric):
|
||||||
|
|
@ -374,6 +532,9 @@ def render_training_analysis_page():
|
||||||
return
|
return
|
||||||
selected_result, selected_metric, refit_metric, top_n = selection_result
|
selected_result, selected_metric, refit_metric, top_n = selection_result
|
||||||
|
|
||||||
|
# Load predictions with labels once (used by confusion matrix and map)
|
||||||
|
merged_predictions = load_predictions_with_labels(selected_result)
|
||||||
|
|
||||||
# Main content area
|
# Main content area
|
||||||
results = selected_result.results
|
results = selected_result.results
|
||||||
settings = selected_result.settings
|
settings = selected_result.settings
|
||||||
|
|
@ -389,7 +550,7 @@ def render_training_analysis_page():
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
# Confusion Matrix Section
|
# Confusion Matrix Section
|
||||||
render_confusion_matrix_section(selected_result)
|
render_confusion_matrix_section(selected_result, merged_predictions)
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
|
|
@ -400,9 +561,10 @@ def render_training_analysis_page():
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
# Confusion Matrix Map Section
|
# Prediction Analysis Map Section
|
||||||
st.header("🗺️ Prediction Results Map")
|
st.header("🗺️ Model Performance Map")
|
||||||
render_confusion_matrix_map(selected_result.path, settings)
|
st.caption("Interactive 3D map showing prediction correctness across the training dataset")
|
||||||
|
render_confusion_matrix_map(selected_result.path, settings, merged_predictions)
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -734,14 +734,14 @@ def spatial_agg(
|
||||||
3: _Aggregations.common(),
|
3: _Aggregations.common(),
|
||||||
4: _Aggregations.common(),
|
4: _Aggregations.common(),
|
||||||
5: _Aggregations(mean=True),
|
5: _Aggregations(mean=True),
|
||||||
6: "interpolate",
|
6: "interpolate", # nearest neighbor interpolation
|
||||||
},
|
},
|
||||||
"healpix": {
|
"healpix": {
|
||||||
6: _Aggregations.common(),
|
6: _Aggregations.common(),
|
||||||
7: _Aggregations.common(),
|
7: _Aggregations.common(),
|
||||||
8: _Aggregations(mean=True),
|
8: _Aggregations(mean=True),
|
||||||
9: _Aggregations(mean=True),
|
9: _Aggregations(mean=True),
|
||||||
10: "interpolate",
|
10: "interpolate", # nearest neighbor interpolation
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
aggregations = aggregations_by_gridlevel[grid][level]
|
aggregations = aggregations_by_gridlevel[grid][level]
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ import cupy as cp
|
||||||
import cyclopts
|
import cyclopts
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import toml
|
import toml
|
||||||
|
import torch
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from array_api_compat import get_namespace
|
|
||||||
from cuml.ensemble import RandomForestClassifier
|
from cuml.ensemble import RandomForestClassifier
|
||||||
from cuml.neighbors import KNeighborsClassifier
|
from cuml.neighbors import KNeighborsClassifier
|
||||||
from entropy import ESPAClassifier
|
from entropy import ESPAClassifier
|
||||||
|
|
@ -233,10 +233,9 @@ def random_cv(
|
||||||
# Compute predictions on the test set
|
# Compute predictions on the test set
|
||||||
y_pred = best_estimator.predict(training_data.X.test)
|
y_pred = best_estimator.predict(training_data.X.test)
|
||||||
labels = list(range(len(training_data.y.labels)))
|
labels = list(range(len(training_data.y.labels)))
|
||||||
xp = get_namespace(y_test)
|
y_test = torch.asarray(y_test, device="cuda")
|
||||||
y_test = xp.as_array(y_test)
|
y_pred = torch.asarray(y_pred, device="cuda")
|
||||||
y_pred = xp.as_array(y_pred)
|
labels = torch.asarray(labels, device="cuda")
|
||||||
labels = xp.as_array(labels)
|
|
||||||
|
|
||||||
test_metrics = {metric: _metric_functions[metric](y_test, y_pred) for metric in metrics}
|
test_metrics = {metric: _metric_functions[metric](y_test, y_pred) for metric in metrics}
|
||||||
|
|
||||||
|
|
@ -244,7 +243,7 @@ def random_cv(
|
||||||
cm = confusion_matrix(y_test, y_pred, labels=labels)
|
cm = confusion_matrix(y_test, y_pred, labels=labels)
|
||||||
label_names = [training_data.y.labels[i] for i in range(len(training_data.y.labels))]
|
label_names = [training_data.y.labels[i] for i in range(len(training_data.y.labels))]
|
||||||
cm = xr.DataArray(
|
cm = xr.DataArray(
|
||||||
xp.as_numpy(cm),
|
cm.cpu().numpy(),
|
||||||
dims=["true_label", "predicted_label"],
|
dims=["true_label", "predicted_label"],
|
||||||
coords={"true_label": label_names, "predicted_label": label_names},
|
coords={"true_label": label_names, "predicted_label": label_names},
|
||||||
name="confusion_matrix",
|
name="confusion_matrix",
|
||||||
|
|
|
||||||
|
|
@ -368,7 +368,7 @@ def _init_worker(r: xr.Dataset | None):
|
||||||
def _align_partition(
|
def _align_partition(
|
||||||
grid_partition_gdf: gpd.GeoDataFrame,
|
grid_partition_gdf: gpd.GeoDataFrame,
|
||||||
raster: xr.Dataset | Callable[[], xr.Dataset] | None,
|
raster: xr.Dataset | Callable[[], xr.Dataset] | None,
|
||||||
aggregations: _Aggregations | None, # None -> Interpolation
|
aggregations: _Aggregations | str, # str -> Interpolation method
|
||||||
pxbuffer: int,
|
pxbuffer: int,
|
||||||
):
|
):
|
||||||
# ? This function is expected to run inside a worker process
|
# ? This function is expected to run inside a worker process
|
||||||
|
|
@ -441,7 +441,7 @@ def _align_partition(
|
||||||
)
|
)
|
||||||
memprof.log_memory("After reading partial raster", log=False)
|
memprof.log_memory("After reading partial raster", log=False)
|
||||||
|
|
||||||
if aggregations is None:
|
if isinstance(aggregations, str):
|
||||||
cell_ids = grids.convert_cell_ids(grid_partition_gdf)
|
cell_ids = grids.convert_cell_ids(grid_partition_gdf)
|
||||||
if grid_partition_gdf.crs.to_epsg() == 4326:
|
if grid_partition_gdf.crs.to_epsg() == 4326:
|
||||||
centroids = grid_partition_gdf.geometry.apply(antimeridian.fix_shape).apply(antimeridian.centroid)
|
centroids = grid_partition_gdf.geometry.apply(antimeridian.fix_shape).apply(antimeridian.centroid)
|
||||||
|
|
@ -453,14 +453,26 @@ def _align_partition(
|
||||||
cy = centroids.y
|
cy = centroids.y
|
||||||
interp_x = xr.DataArray(cx, dims=["cell_ids"], coords={"cell_ids": cell_ids})
|
interp_x = xr.DataArray(cx, dims=["cell_ids"], coords={"cell_ids": cell_ids})
|
||||||
interp_y = xr.DataArray(cy, dims=["cell_ids"], coords={"cell_ids": cell_ids})
|
interp_y = xr.DataArray(cy, dims=["cell_ids"], coords={"cell_ids": cell_ids})
|
||||||
interp_coords = (
|
ydim = "latitude" if "latitude" in raster.dims else "y"
|
||||||
{"latitude": interp_y, "longitude": interp_x}
|
xdim = "longitude" if "longitude" in raster.dims else "x"
|
||||||
if "latitude" in raster.dims and "longitude" in raster.dims
|
interp_coords = {ydim: interp_y, xdim: interp_x}
|
||||||
else {"y": interp_y, "x": interp_x}
|
|
||||||
)
|
|
||||||
# ?: Cubic does not work with NaNs in xarray interp
|
# ?: Cubic does not work with NaNs in xarray interp
|
||||||
with stopwatch("Interpolating data to grid centroids", log=False):
|
with stopwatch("Interpolating data to grid centroids", log=False):
|
||||||
ongrid = partial_raster.interp(interp_coords, method="linear", kwargs={"fill_value": np.nan})
|
# Fill the nan
|
||||||
|
y_is_rev = partial_raster.indexes[ydim].is_monotonic_decreasing
|
||||||
|
if y_is_rev:
|
||||||
|
partial_raster = partial_raster.sortby(ydim)
|
||||||
|
partial_raster = partial_raster.interpolate_na(
|
||||||
|
dim=ydim,
|
||||||
|
method=aggregations,
|
||||||
|
).interpolate_na(
|
||||||
|
dim=xdim,
|
||||||
|
method=aggregations,
|
||||||
|
)
|
||||||
|
if y_is_rev:
|
||||||
|
partial_raster = partial_raster.sortby(ydim)
|
||||||
|
ongrid = partial_raster.interp(interp_coords, method=aggregations, kwargs={"fill_value": np.nan})
|
||||||
memprof.log_memory("After interpolating data", log=False)
|
memprof.log_memory("After interpolating data", log=False)
|
||||||
else:
|
else:
|
||||||
others_shape = tuple(
|
others_shape = tuple(
|
||||||
|
|
@ -519,7 +531,7 @@ def _align_partition(
|
||||||
def _align_data(
|
def _align_data(
|
||||||
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
|
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
|
||||||
raster: xr.Dataset | Callable[[], xr.Dataset],
|
raster: xr.Dataset | Callable[[], xr.Dataset],
|
||||||
aggregations: _Aggregations | None,
|
aggregations: _Aggregations | str,
|
||||||
n_partitions: int | None,
|
n_partitions: int | None,
|
||||||
concurrent_partitions: int,
|
concurrent_partitions: int,
|
||||||
pxbuffer: int,
|
pxbuffer: int,
|
||||||
|
|
@ -620,7 +632,7 @@ def _align_data(
|
||||||
def aggregate_raster_into_grid(
|
def aggregate_raster_into_grid(
|
||||||
raster: xr.Dataset | Callable[[], xr.Dataset],
|
raster: xr.Dataset | Callable[[], xr.Dataset],
|
||||||
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
|
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
|
||||||
aggregations: _Aggregations | Literal["interpolate"],
|
aggregations: _Aggregations | Literal["nearest", "linear", "cubic", "interpolate"],
|
||||||
grid: Grid,
|
grid: Grid,
|
||||||
level: int,
|
level: int,
|
||||||
n_partitions: int | None = 20,
|
n_partitions: int | None = 20,
|
||||||
|
|
@ -634,7 +646,11 @@ def aggregate_raster_into_grid(
|
||||||
grid_gdf (gpd.GeoDataFrame | list[gpd.GeoDataFrame]): The grid to aggregate into.
|
grid_gdf (gpd.GeoDataFrame | list[gpd.GeoDataFrame]): The grid to aggregate into.
|
||||||
If a list of GeoDataFrames is provided, each will be processed as a separate partition.
|
If a list of GeoDataFrames is provided, each will be processed as a separate partition.
|
||||||
No further partitioning will be done and the n_partitions argument will be ignored.
|
No further partitioning will be done and the n_partitions argument will be ignored.
|
||||||
aggregations (_Aggregations | Literal["interpolate"]): The aggregations to perform.
|
aggregations (_Aggregations | Literal[["nearest", "linear", "cubic", "interpolate"]):
|
||||||
|
The aggregations to perform.
|
||||||
|
If a string is provided, interpolation will be used with the specified method.
|
||||||
|
If "interpolate" is provided, the nearest neighbor interpolation will be used.
|
||||||
|
Supported methods are "nearest", "linear", and "cubic".
|
||||||
grid (Grid): The type of grid to use.
|
grid (Grid): The type of grid to use.
|
||||||
level (int): The level of the grid.
|
level (int): The level of the grid.
|
||||||
n_partitions (int | None, optional): Number of partitions to divide the grid into. Defaults to 20.
|
n_partitions (int | None, optional): Number of partitions to divide the grid into. Defaults to 20.
|
||||||
|
|
@ -649,7 +665,7 @@ def aggregate_raster_into_grid(
|
||||||
ongrid = _align_data(
|
ongrid = _align_data(
|
||||||
grid_gdf,
|
grid_gdf,
|
||||||
raster,
|
raster,
|
||||||
aggregations if aggregations != "interpolate" else None,
|
aggregations if aggregations != "interpolate" else "nearest",
|
||||||
n_partitions=n_partitions,
|
n_partitions=n_partitions,
|
||||||
concurrent_partitions=concurrent_partitions,
|
concurrent_partitions=concurrent_partitions,
|
||||||
pxbuffer=pxbuffer,
|
pxbuffer=pxbuffer,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue