Enhance training analysis page with test metrics and confusion matrix
- Added a section to display test metrics for model performance on the held-out test set. - Implemented confusion matrix visualization to analyze prediction breakdown. - Refactored sidebar settings to streamline metric selection and improve user experience. - Updated cross-validation statistics to compare CV performance with test metrics. - Enhanced DatasetEnsemble methods to handle empty data scenarios gracefully. - Introduced debug scripts to assist in identifying feature mismatches and validating dataset preparation. - Added comprehensive tests for DatasetEnsemble to ensure feature consistency and correct behavior across various scenarios.
This commit is contained in:
parent
4fecac535c
commit
c92e856c55
23 changed files with 1845 additions and 484 deletions
33
tests/debug_arcticdem_batch.py
Normal file
33
tests/debug_arcticdem_batch.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
"""Debug script to check what _prep_arcticdem returns for a batch."""
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
|
||||
ensemble = DatasetEnsemble(
|
||||
grid="healpix",
|
||||
level=10,
|
||||
target="darts_mllabels",
|
||||
members=["ArcticDEM"],
|
||||
add_lonlat=True,
|
||||
filter_target=False,
|
||||
)
|
||||
|
||||
# Get targets
|
||||
targets = ensemble._read_target()
|
||||
print(f"Total targets: {len(targets)}")
|
||||
|
||||
# Get first batch of targets
|
||||
batch_targets = targets.iloc[:100]
|
||||
print(f"\nBatch targets: {len(batch_targets)}")
|
||||
print(f"Cell IDs in batch: {batch_targets['cell_id'].values[:5]}")
|
||||
|
||||
# Try to prep ArcticDEM for this batch
|
||||
print("\n" + "=" * 80)
|
||||
print("Calling _prep_arcticdem...")
|
||||
print("=" * 80)
|
||||
arcticdem_df = ensemble._prep_arcticdem(batch_targets)
|
||||
print(f"\nArcticDEM DataFrame shape: {arcticdem_df.shape}")
|
||||
print(f"ArcticDEM DataFrame index: {arcticdem_df.index[:5].tolist() if len(arcticdem_df) > 0 else 'EMPTY'}")
|
||||
print(
|
||||
f"ArcticDEM DataFrame columns ({len(arcticdem_df.columns)}): {arcticdem_df.columns[:10].tolist() if len(arcticdem_df.columns) > 0 else 'NO COLUMNS'}"
|
||||
)
|
||||
print(f"Number of non-NaN rows: {arcticdem_df.notna().any(axis=1).sum()}")
|
||||
72
tests/debug_feature_mismatch.py
Normal file
72
tests/debug_feature_mismatch.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
"""Debug script to identify feature mismatch between training and inference."""
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
|
||||
# Test with level 6 (the actual level used in production)
|
||||
ensemble = DatasetEnsemble(
|
||||
grid="healpix",
|
||||
level=10,
|
||||
target="darts_mllabels",
|
||||
members=[
|
||||
"AlphaEarth",
|
||||
"ArcticDEM",
|
||||
"ERA5-yearly",
|
||||
"ERA5-seasonal",
|
||||
"ERA5-shoulder",
|
||||
],
|
||||
add_lonlat=True,
|
||||
filter_target=False,
|
||||
)
|
||||
|
||||
print("=" * 80)
|
||||
print("Creating training dataset...")
|
||||
print("=" * 80)
|
||||
training_data = ensemble.create_cat_training_dataset(task="binary", device="cpu")
|
||||
training_features = set(training_data.X.data.columns)
|
||||
print(f"\nTraining dataset created with {len(training_features)} features")
|
||||
print(f"Sample features: {sorted(list(training_features))[:10]}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("Creating inference batch...")
|
||||
print("=" * 80)
|
||||
batch_generator = ensemble.create_batches(batch_size=100, cache_mode="n")
|
||||
batch = next(batch_generator, None)
|
||||
# for batch in batch_generator:
|
||||
if batch is None:
|
||||
print("ERROR: No batch created!")
|
||||
else:
|
||||
print(f"\nBatch created with {len(batch.columns)} columns")
|
||||
print(f"Batch columns: {sorted(batch.columns)[:15]}")
|
||||
|
||||
# Simulate the column dropping in predict_proba (inference.py)
|
||||
cols_to_drop = ["geometry"]
|
||||
if ensemble.target == "darts_mllabels":
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
|
||||
else:
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
|
||||
|
||||
print(f"\nColumns to drop: {cols_to_drop}")
|
||||
|
||||
inference_batch = batch.drop(columns=cols_to_drop)
|
||||
inference_features = set(inference_batch.columns)
|
||||
|
||||
print(f"\nInference batch after dropping has {len(inference_features)} features")
|
||||
print(f"Sample features: {sorted(list(inference_features))[:10]}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("COMPARISON")
|
||||
print("=" * 80)
|
||||
print(f"Training features: {len(training_features)}")
|
||||
print(f"Inference features: {len(inference_features)}")
|
||||
|
||||
if training_features == inference_features:
|
||||
print("\n✅ SUCCESS: Features match perfectly!")
|
||||
else:
|
||||
print("\n❌ MISMATCH DETECTED!")
|
||||
only_in_training = training_features - inference_features
|
||||
only_in_inference = inference_features - training_features
|
||||
|
||||
if only_in_training:
|
||||
print(f"\n⚠️ Only in TRAINING ({len(only_in_training)}): {sorted(only_in_training)}")
|
||||
if only_in_inference:
|
||||
print(f"\n⚠️ Only in INFERENCE ({len(only_in_inference)}): {sorted(only_in_inference)}")
|
||||
310
tests/test_dataset.py
Normal file
310
tests/test_dataset.py
Normal file
|
|
@ -0,0 +1,310 @@
|
|||
"""Tests for dataset.py module, specifically DatasetEnsemble class."""
|
||||
|
||||
import geopandas as gpd
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_ensemble():
|
||||
"""Create a sample DatasetEnsemble for testing with minimal data."""
|
||||
return DatasetEnsemble(
|
||||
grid="hex",
|
||||
level=3, # Use level 3 for much faster tests
|
||||
target="darts_rts",
|
||||
members=["AlphaEarth"], # Use only one member for faster tests
|
||||
add_lonlat=True,
|
||||
filter_target="darts_has_coverage", # Filter to reduce dataset size
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_ensemble_mllabels():
|
||||
"""Create a sample DatasetEnsemble with mllabels target."""
|
||||
return DatasetEnsemble(
|
||||
grid="hex",
|
||||
level=3, # Use level 3 for much faster tests
|
||||
target="darts_mllabels",
|
||||
members=["AlphaEarth"], # Use only one member for faster tests
|
||||
add_lonlat=True,
|
||||
filter_target="dartsml_has_coverage", # Filter to reduce dataset size
|
||||
)
|
||||
|
||||
|
||||
class TestDatasetEnsemble:
|
||||
"""Test suite for DatasetEnsemble class."""
|
||||
|
||||
def test_initialization(self, sample_ensemble):
|
||||
"""Test that DatasetEnsemble initializes correctly."""
|
||||
assert sample_ensemble.grid == "hex"
|
||||
assert sample_ensemble.level == 3
|
||||
assert sample_ensemble.target == "darts_rts"
|
||||
assert "AlphaEarth" in sample_ensemble.members
|
||||
assert sample_ensemble.add_lonlat is True
|
||||
|
||||
def test_covcol_property(self, sample_ensemble, sample_ensemble_mllabels):
|
||||
"""Test that covcol property returns correct column name."""
|
||||
assert sample_ensemble.covcol == "darts_has_coverage"
|
||||
assert sample_ensemble_mllabels.covcol == "dartsml_has_coverage"
|
||||
|
||||
def test_taskcol_property(self, sample_ensemble, sample_ensemble_mllabels):
|
||||
"""Test that taskcol returns correct column name for different tasks."""
|
||||
assert sample_ensemble.taskcol("binary") == "darts_has_rts"
|
||||
assert sample_ensemble.taskcol("count") == "darts_rts_count"
|
||||
assert sample_ensemble.taskcol("density") == "darts_rts_density"
|
||||
|
||||
assert sample_ensemble_mllabels.taskcol("binary") == "dartsml_has_rts"
|
||||
assert sample_ensemble_mllabels.taskcol("count") == "dartsml_rts_count"
|
||||
assert sample_ensemble_mllabels.taskcol("density") == "dartsml_rts_density"
|
||||
|
||||
def test_create_returns_geodataframe(self, sample_ensemble):
|
||||
"""Test that create() returns a GeoDataFrame."""
|
||||
dataset = sample_ensemble.create(cache_mode="n")
|
||||
assert isinstance(dataset, gpd.GeoDataFrame)
|
||||
|
||||
def test_create_has_expected_columns(self, sample_ensemble):
|
||||
"""Test that create() returns dataset with expected columns."""
|
||||
dataset = sample_ensemble.create(cache_mode="n")
|
||||
|
||||
# Should have geometry column
|
||||
assert "geometry" in dataset.columns
|
||||
|
||||
# Should have lat/lon if add_lonlat is True
|
||||
if sample_ensemble.add_lonlat:
|
||||
assert "lon" in dataset.columns
|
||||
assert "lat" in dataset.columns
|
||||
|
||||
# Should have target columns (darts_*)
|
||||
assert any(col.startswith("darts_") for col in dataset.columns)
|
||||
|
||||
# Should have member data columns
|
||||
assert len(dataset.columns) > 3 # More than just geometry, lat, lon
|
||||
|
||||
def test_create_batches_consistency(self, sample_ensemble):
|
||||
"""Test that create_batches produces batches with consistent columns."""
|
||||
batch_size = 100
|
||||
batches = list(sample_ensemble.create_batches(batch_size=batch_size, cache_mode="n"))
|
||||
|
||||
if len(batches) == 0:
|
||||
pytest.skip("No batches created (dataset might be empty)")
|
||||
|
||||
# All batches should have the same columns
|
||||
first_batch_cols = set(batches[0].columns)
|
||||
for i, batch in enumerate(batches[1:], start=1):
|
||||
assert set(batch.columns) == first_batch_cols, (
|
||||
f"Batch {i} has different columns than batch 0. "
|
||||
f"Difference: {set(batch.columns).symmetric_difference(first_batch_cols)}"
|
||||
)
|
||||
|
||||
def test_create_vs_create_batches_columns(self, sample_ensemble):
|
||||
"""Test that create() and create_batches() return datasets with the same columns."""
|
||||
full_dataset = sample_ensemble.create(cache_mode="n")
|
||||
batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))
|
||||
|
||||
if len(batches) == 0:
|
||||
pytest.skip("No batches created (dataset might be empty)")
|
||||
|
||||
# Columns should be identical
|
||||
full_cols = set(full_dataset.columns)
|
||||
batch_cols = set(batches[0].columns)
|
||||
|
||||
assert full_cols == batch_cols, (
|
||||
f"Column mismatch between create() and create_batches().\n"
|
||||
f"Only in create(): {full_cols - batch_cols}\n"
|
||||
f"Only in create_batches(): {batch_cols - full_cols}"
|
||||
)
|
||||
|
||||
def test_training_dataset_feature_columns(self, sample_ensemble):
|
||||
"""Test that create_cat_training_dataset creates proper feature columns."""
|
||||
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
|
||||
|
||||
# Get the columns used for model inputs
|
||||
model_input_cols = set(training_data.X.data.columns)
|
||||
|
||||
# These columns should NOT be in model inputs
|
||||
assert "geometry" not in model_input_cols
|
||||
assert sample_ensemble.covcol not in model_input_cols
|
||||
assert sample_ensemble.taskcol("binary") not in model_input_cols
|
||||
|
||||
# No darts_* columns should be in model inputs
|
||||
for col in model_input_cols:
|
||||
assert not col.startswith("darts_"), f"Column {col} should have been dropped from model inputs"
|
||||
|
||||
def test_training_dataset_feature_columns_mllabels(self, sample_ensemble_mllabels):
|
||||
"""Test feature columns for mllabels target."""
|
||||
training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")
|
||||
|
||||
model_input_cols = set(training_data.X.data.columns)
|
||||
|
||||
# These columns should NOT be in model inputs
|
||||
assert "geometry" not in model_input_cols
|
||||
assert sample_ensemble_mllabels.covcol not in model_input_cols
|
||||
assert sample_ensemble_mllabels.taskcol("binary") not in model_input_cols
|
||||
|
||||
# No dartsml_* columns should be in model inputs
|
||||
for col in model_input_cols:
|
||||
assert not col.startswith("dartsml_"), f"Column {col} should have been dropped from model inputs"
|
||||
|
||||
def test_inference_vs_training_feature_consistency(self, sample_ensemble):
|
||||
"""Test that inference batches have the same features as training data after column dropping.
|
||||
|
||||
This test simulates the workflow in training.py and inference.py to ensure
|
||||
feature consistency between training and inference.
|
||||
"""
|
||||
# Step 1: Create training dataset (as in training.py)
|
||||
# Use only a small subset by creating just one batch
|
||||
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
|
||||
training_feature_cols = set(training_data.X.data.columns)
|
||||
|
||||
# Step 2: Create inference batch (as in inference.py)
|
||||
# Get just the first batch to speed up test
|
||||
batch_generator = sample_ensemble.create_batches(batch_size=100, cache_mode="n")
|
||||
batch = next(batch_generator, None)
|
||||
|
||||
if batch is None:
|
||||
pytest.skip("No batches created (dataset might be empty)")
|
||||
|
||||
# Simulate the column dropping in predict_proba
|
||||
cols_to_drop = ["geometry"]
|
||||
if sample_ensemble.target == "darts_mllabels":
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
|
||||
else:
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
|
||||
|
||||
inference_batch = batch.drop(columns=cols_to_drop)
|
||||
inference_feature_cols = set(inference_batch.columns)
|
||||
|
||||
# The features should match!
|
||||
assert training_feature_cols == inference_feature_cols, (
|
||||
f"Feature mismatch between training and inference!\n"
|
||||
f"Only in training: {training_feature_cols - inference_feature_cols}\n"
|
||||
f"Only in inference: {inference_feature_cols - training_feature_cols}\n"
|
||||
f"Training features ({len(training_feature_cols)}): {sorted(training_feature_cols)}\n"
|
||||
f"Inference features ({len(inference_feature_cols)}): {sorted(inference_feature_cols)}"
|
||||
)
|
||||
|
||||
def test_inference_vs_training_feature_consistency_mllabels(self, sample_ensemble_mllabels):
|
||||
"""Test feature consistency for mllabels target."""
|
||||
training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")
|
||||
training_feature_cols = set(training_data.X.data.columns)
|
||||
|
||||
# Get just the first batch to speed up test
|
||||
batch_generator = sample_ensemble_mllabels.create_batches(batch_size=100, cache_mode="n")
|
||||
batch = next(batch_generator, None)
|
||||
|
||||
if batch is None:
|
||||
pytest.skip("No batches created (dataset might be empty)")
|
||||
|
||||
# Simulate the column dropping in predict_proba
|
||||
cols_to_drop = ["geometry"]
|
||||
if sample_ensemble_mllabels.target == "darts_mllabels":
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
|
||||
else:
|
||||
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
|
||||
|
||||
inference_batch = batch.drop(columns=cols_to_drop)
|
||||
inference_feature_cols = set(inference_batch.columns)
|
||||
|
||||
assert training_feature_cols == inference_feature_cols, (
|
||||
f"Feature mismatch between training and inference!\n"
|
||||
f"Only in training: {training_feature_cols - inference_feature_cols}\n"
|
||||
f"Only in inference: {inference_feature_cols - training_feature_cols}"
|
||||
)
|
||||
|
||||
def test_all_tasks_feature_consistency(self, sample_ensemble):
|
||||
"""Test that all task types produce consistent features."""
|
||||
tasks = ["binary", "count", "density"]
|
||||
feature_sets = {}
|
||||
|
||||
for task in tasks:
|
||||
training_data = sample_ensemble.create_cat_training_dataset(task=task, device="cpu")
|
||||
feature_sets[task] = set(training_data.X.data.columns)
|
||||
|
||||
# All tasks should have the same features
|
||||
binary_features = feature_sets["binary"]
|
||||
for task, features in feature_sets.items():
|
||||
assert features == binary_features, (
|
||||
f"Task '{task}' has different features than 'binary'.\n"
|
||||
f"Difference: {features.symmetric_difference(binary_features)}"
|
||||
)
|
||||
|
||||
def test_training_dataset_shapes(self, sample_ensemble):
|
||||
"""Test that training dataset has correct shapes."""
|
||||
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
|
||||
|
||||
n_features = len(training_data.X.data.columns)
|
||||
n_samples_train = training_data.X.train.shape[0]
|
||||
n_samples_test = training_data.X.test.shape[0]
|
||||
|
||||
# Check X shapes
|
||||
assert training_data.X.train.shape == (n_samples_train, n_features)
|
||||
assert training_data.X.test.shape == (n_samples_test, n_features)
|
||||
|
||||
# Check y shapes
|
||||
assert training_data.y.train.shape == (n_samples_train,)
|
||||
assert training_data.y.test.shape == (n_samples_test,)
|
||||
|
||||
# Check that train + test = total samples
|
||||
assert len(training_data.dataset) == n_samples_train + n_samples_test
|
||||
|
||||
def test_no_nan_in_training_features(self, sample_ensemble):
|
||||
"""Test that training features don't contain NaN values."""
|
||||
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
|
||||
|
||||
# Convert to numpy for checking (handles both numpy and cupy arrays)
|
||||
X_train = np.asarray(training_data.X.train)
|
||||
X_test = np.asarray(training_data.X.test)
|
||||
|
||||
assert not np.isnan(X_train).any(), "Training features contain NaN values"
|
||||
assert not np.isnan(X_test).any(), "Test features contain NaN values"
|
||||
|
||||
def test_batch_coverage(self, sample_ensemble):
|
||||
"""Test that batches cover all data without duplication."""
|
||||
full_dataset = sample_ensemble.create(cache_mode="n")
|
||||
batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))
|
||||
|
||||
if len(batches) == 0:
|
||||
pytest.skip("No batches created (dataset might be empty)")
|
||||
|
||||
# Collect all cell_ids from batches
|
||||
batch_cell_ids = set()
|
||||
for batch in batches:
|
||||
batch_ids = set(batch.index)
|
||||
# Check for duplicates across batches
|
||||
overlap = batch_cell_ids.intersection(batch_ids)
|
||||
assert len(overlap) == 0, f"Found {len(overlap)} duplicate cell_ids across batches"
|
||||
batch_cell_ids.update(batch_ids)
|
||||
|
||||
# Check that all cell_ids from full dataset are in batches
|
||||
full_cell_ids = set(full_dataset.index)
|
||||
assert batch_cell_ids == full_cell_ids, (
|
||||
f"Batch coverage mismatch.\n"
|
||||
f"Missing from batches: {full_cell_ids - batch_cell_ids}\n"
|
||||
f"Extra in batches: {batch_cell_ids - full_cell_ids}"
|
||||
)
|
||||
|
||||
|
||||
class TestDatasetEnsembleEdgeCases:
|
||||
"""Test edge cases and error handling."""
|
||||
|
||||
def test_invalid_task_raises_error(self, sample_ensemble):
|
||||
"""Test that invalid task raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid task"):
|
||||
sample_ensemble.create_cat_training_dataset(task="invalid", device="cpu") # type: ignore
|
||||
|
||||
def test_stats_method(self, sample_ensemble):
|
||||
"""Test that get_stats returns expected structure."""
|
||||
stats = sample_ensemble.get_stats()
|
||||
|
||||
assert "target" in stats
|
||||
assert "num_target_samples" in stats
|
||||
assert "members" in stats
|
||||
assert "total_features" in stats
|
||||
|
||||
# Check that members dict contains info for each member
|
||||
for member in sample_ensemble.members:
|
||||
assert member in stats["members"]
|
||||
assert "variables" in stats["members"][member]
|
||||
assert "num_features" in stats["members"][member]
|
||||
Loading…
Add table
Add a link
Reference in a new issue