# Change notes:
# - Added a section to display test metrics for model performance on the held-out test set.
# - Implemented confusion matrix visualization to analyze prediction breakdown.
# - Refactored sidebar settings to streamline metric selection and improve user experience.
# - Updated cross-validation statistics to compare CV performance with test metrics.
# - Enhanced DatasetEnsemble methods to handle empty data scenarios gracefully.
# - Introduced debug scripts to assist in identifying feature mismatches and validating dataset preparation.
# - Added comprehensive tests for DatasetEnsemble to ensure feature consistency and correct behavior across various scenarios.
"""Tests for dataset.py module, specifically DatasetEnsemble class."""
|
|
|
|
import geopandas as gpd
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from entropice.ml.dataset import DatasetEnsemble
|
|
|
|
|
|
@pytest.fixture
def sample_ensemble():
    """Provide a minimal DatasetEnsemble (darts_rts target) for testing.

    Grid level 3 and a single ensemble member keep the fixture cheap to
    build; filtering on coverage shrinks the dataset further so tests run fast.
    """
    return DatasetEnsemble(
        grid="hex",
        level=3,
        target="darts_rts",
        members=["AlphaEarth"],
        add_lonlat=True,
        filter_target="darts_has_coverage",
    )
@pytest.fixture
def sample_ensemble_mllabels():
    """Provide a minimal DatasetEnsemble with the darts_mllabels target.

    Mirrors ``sample_ensemble`` (level 3, one member, coverage filter) but
    targets the ML-labels dataset so column-prefix handling can be tested.
    """
    return DatasetEnsemble(
        grid="hex",
        level=3,
        target="darts_mllabels",
        members=["AlphaEarth"],
        add_lonlat=True,
        filter_target="dartsml_has_coverage",
    )
class TestDatasetEnsemble:
    """Test suite for DatasetEnsemble class."""

    @staticmethod
    def _inference_feature_cols(ensemble, batch):
        """Return the feature columns left after predict_proba's column dropping.

        Replicates the dropping logic in inference.py: "geometry" plus every
        target-prefixed column ("dartsml_*" for the darts_mllabels target,
        "darts_*" otherwise) is removed before a batch is fed to the model.
        Shared by both feature-consistency tests so the simulated logic cannot
        drift between them.
        """
        prefix = "dartsml_" if ensemble.target == "darts_mllabels" else "darts_"
        cols_to_drop = ["geometry"] + [col for col in batch.columns if col.startswith(prefix)]
        return set(batch.drop(columns=cols_to_drop).columns)

    @staticmethod
    def _first_batch_or_skip(ensemble, batch_size=100):
        """Return the first inference batch, skipping the test if none exist."""
        batch = next(ensemble.create_batches(batch_size=batch_size, cache_mode="n"), None)
        if batch is None:
            pytest.skip("No batches created (dataset might be empty)")
        return batch

    def test_initialization(self, sample_ensemble):
        """Test that DatasetEnsemble initializes correctly."""
        assert sample_ensemble.grid == "hex"
        assert sample_ensemble.level == 3
        assert sample_ensemble.target == "darts_rts"
        assert "AlphaEarth" in sample_ensemble.members
        assert sample_ensemble.add_lonlat is True

    def test_covcol_property(self, sample_ensemble, sample_ensemble_mllabels):
        """Test that covcol property returns correct column name."""
        assert sample_ensemble.covcol == "darts_has_coverage"
        assert sample_ensemble_mllabels.covcol == "dartsml_has_coverage"

    def test_taskcol_property(self, sample_ensemble, sample_ensemble_mllabels):
        """Test that taskcol returns correct column name for different tasks."""
        assert sample_ensemble.taskcol("binary") == "darts_has_rts"
        assert sample_ensemble.taskcol("count") == "darts_rts_count"
        assert sample_ensemble.taskcol("density") == "darts_rts_density"

        assert sample_ensemble_mllabels.taskcol("binary") == "dartsml_has_rts"
        assert sample_ensemble_mllabels.taskcol("count") == "dartsml_rts_count"
        assert sample_ensemble_mllabels.taskcol("density") == "dartsml_rts_density"

    def test_create_returns_geodataframe(self, sample_ensemble):
        """Test that create() returns a GeoDataFrame."""
        dataset = sample_ensemble.create(cache_mode="n")
        assert isinstance(dataset, gpd.GeoDataFrame)

    def test_create_has_expected_columns(self, sample_ensemble):
        """Test that create() returns dataset with expected columns."""
        dataset = sample_ensemble.create(cache_mode="n")

        # Should have geometry column
        assert "geometry" in dataset.columns

        # Should have lat/lon if add_lonlat is True
        if sample_ensemble.add_lonlat:
            assert "lon" in dataset.columns
            assert "lat" in dataset.columns

        # Should have target columns (darts_*)
        assert any(col.startswith("darts_") for col in dataset.columns)

        # Should have member data columns
        assert len(dataset.columns) > 3  # More than just geometry, lat, lon

    def test_create_batches_consistency(self, sample_ensemble):
        """Test that create_batches produces batches with consistent columns."""
        batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))

        if len(batches) == 0:
            pytest.skip("No batches created (dataset might be empty)")

        # All batches should have the same columns as the first one
        first_batch_cols = set(batches[0].columns)
        for i, batch in enumerate(batches[1:], start=1):
            assert set(batch.columns) == first_batch_cols, (
                f"Batch {i} has different columns than batch 0. "
                f"Difference: {set(batch.columns).symmetric_difference(first_batch_cols)}"
            )

    def test_create_vs_create_batches_columns(self, sample_ensemble):
        """Test that create() and create_batches() return datasets with the same columns."""
        full_dataset = sample_ensemble.create(cache_mode="n")
        batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))

        if len(batches) == 0:
            pytest.skip("No batches created (dataset might be empty)")

        # Columns should be identical
        full_cols = set(full_dataset.columns)
        batch_cols = set(batches[0].columns)

        assert full_cols == batch_cols, (
            f"Column mismatch between create() and create_batches().\n"
            f"Only in create(): {full_cols - batch_cols}\n"
            f"Only in create_batches(): {batch_cols - full_cols}"
        )

    def test_training_dataset_feature_columns(self, sample_ensemble):
        """Test that create_cat_training_dataset creates proper feature columns."""
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")

        # Get the columns used for model inputs
        model_input_cols = set(training_data.X.data.columns)

        # These columns should NOT be in model inputs
        assert "geometry" not in model_input_cols
        assert sample_ensemble.covcol not in model_input_cols
        assert sample_ensemble.taskcol("binary") not in model_input_cols

        # No darts_* columns should be in model inputs
        for col in model_input_cols:
            assert not col.startswith("darts_"), f"Column {col} should have been dropped from model inputs"

    def test_training_dataset_feature_columns_mllabels(self, sample_ensemble_mllabels):
        """Test feature columns for mllabels target."""
        training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")

        model_input_cols = set(training_data.X.data.columns)

        # These columns should NOT be in model inputs
        assert "geometry" not in model_input_cols
        assert sample_ensemble_mllabels.covcol not in model_input_cols
        assert sample_ensemble_mllabels.taskcol("binary") not in model_input_cols

        # No dartsml_* columns should be in model inputs
        for col in model_input_cols:
            assert not col.startswith("dartsml_"), f"Column {col} should have been dropped from model inputs"

    def test_inference_vs_training_feature_consistency(self, sample_ensemble):
        """Test that inference batches have the same features as training data after column dropping.

        This test simulates the workflow in training.py and inference.py to ensure
        feature consistency between training and inference.
        """
        # Step 1: Create training dataset (as in training.py)
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
        training_feature_cols = set(training_data.X.data.columns)

        # Step 2: Take one inference batch (as in inference.py) and simulate
        # the column dropping done in predict_proba
        batch = self._first_batch_or_skip(sample_ensemble)
        inference_feature_cols = self._inference_feature_cols(sample_ensemble, batch)

        # The features should match!
        assert training_feature_cols == inference_feature_cols, (
            f"Feature mismatch between training and inference!\n"
            f"Only in training: {training_feature_cols - inference_feature_cols}\n"
            f"Only in inference: {inference_feature_cols - training_feature_cols}\n"
            f"Training features ({len(training_feature_cols)}): {sorted(training_feature_cols)}\n"
            f"Inference features ({len(inference_feature_cols)}): {sorted(inference_feature_cols)}"
        )

    def test_inference_vs_training_feature_consistency_mllabels(self, sample_ensemble_mllabels):
        """Test feature consistency for mllabels target."""
        training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")
        training_feature_cols = set(training_data.X.data.columns)

        # Take one inference batch and simulate predict_proba's column dropping
        batch = self._first_batch_or_skip(sample_ensemble_mllabels)
        inference_feature_cols = self._inference_feature_cols(sample_ensemble_mllabels, batch)

        assert training_feature_cols == inference_feature_cols, (
            f"Feature mismatch between training and inference!\n"
            f"Only in training: {training_feature_cols - inference_feature_cols}\n"
            f"Only in inference: {inference_feature_cols - training_feature_cols}"
        )

    def test_all_tasks_feature_consistency(self, sample_ensemble):
        """Test that all task types produce consistent features."""
        tasks = ["binary", "count", "density"]
        feature_sets = {}

        for task in tasks:
            training_data = sample_ensemble.create_cat_training_dataset(task=task, device="cpu")
            feature_sets[task] = set(training_data.X.data.columns)

        # All tasks should have the same features as the binary task
        binary_features = feature_sets["binary"]
        for task, features in feature_sets.items():
            assert features == binary_features, (
                f"Task '{task}' has different features than 'binary'.\n"
                f"Difference: {features.symmetric_difference(binary_features)}"
            )

    def test_training_dataset_shapes(self, sample_ensemble):
        """Test that training dataset has correct shapes."""
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")

        n_features = len(training_data.X.data.columns)
        n_samples_train = training_data.X.train.shape[0]
        n_samples_test = training_data.X.test.shape[0]

        # Check X shapes
        assert training_data.X.train.shape == (n_samples_train, n_features)
        assert training_data.X.test.shape == (n_samples_test, n_features)

        # Check y shapes
        assert training_data.y.train.shape == (n_samples_train,)
        assert training_data.y.test.shape == (n_samples_test,)

        # Check that train + test = total samples
        assert len(training_data.dataset) == n_samples_train + n_samples_test

    def test_no_nan_in_training_features(self, sample_ensemble):
        """Test that training features don't contain NaN values."""
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")

        # Convert to numpy for checking (handles both numpy and cupy arrays)
        X_train = np.asarray(training_data.X.train)
        X_test = np.asarray(training_data.X.test)

        assert not np.isnan(X_train).any(), "Training features contain NaN values"
        assert not np.isnan(X_test).any(), "Test features contain NaN values"

    def test_batch_coverage(self, sample_ensemble):
        """Test that batches cover all data without duplication."""
        full_dataset = sample_ensemble.create(cache_mode="n")
        batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))

        if len(batches) == 0:
            pytest.skip("No batches created (dataset might be empty)")

        # Collect all cell_ids from batches
        batch_cell_ids = set()
        for batch in batches:
            batch_ids = set(batch.index)
            # Check for duplicates across batches
            overlap = batch_cell_ids.intersection(batch_ids)
            assert len(overlap) == 0, f"Found {len(overlap)} duplicate cell_ids across batches"
            batch_cell_ids.update(batch_ids)

        # Check that all cell_ids from full dataset are in batches
        full_cell_ids = set(full_dataset.index)
        assert batch_cell_ids == full_cell_ids, (
            f"Batch coverage mismatch.\n"
            f"Missing from batches: {full_cell_ids - batch_cell_ids}\n"
            f"Extra in batches: {batch_cell_ids - full_cell_ids}"
        )
class TestDatasetEnsembleEdgeCases:
    """Test edge cases and error handling."""

    def test_invalid_task_raises_error(self, sample_ensemble):
        """An unrecognized task name must be rejected with a ValueError."""
        with pytest.raises(ValueError, match="Invalid task"):
            sample_ensemble.create_cat_training_dataset(task="invalid", device="cpu")  # type: ignore

    def test_stats_method(self, sample_ensemble):
        """get_stats() must expose the expected top-level and per-member keys."""
        stats = sample_ensemble.get_stats()

        # Top-level structure of the stats dict
        for key in ("target", "num_target_samples", "members", "total_features"):
            assert key in stats

        # Each configured member must have its own per-member stats entry
        member_stats = stats["members"]
        for member in sample_ensemble.members:
            assert member in member_stats
            assert "variables" in member_stats[member]
            assert "num_features" in member_stats[member]