"""Tests for the dataset.py module, specifically the DatasetEnsemble class."""

import geopandas as gpd
import numpy as np
import pytest

from entropice.ml.dataset import DatasetEnsemble

_EMPTY_BATCHES_REASON = "No batches created (dataset might be empty)"


def _target_prefix(ensemble):
    """Return the prefix of the target/label columns for *ensemble*.

    ``darts_mllabels`` ensembles carry ``dartsml_*`` columns; every other target
    used in these tests carries ``darts_*`` columns.
    """
    return "dartsml_" if ensemble.target == "darts_mllabels" else "darts_"


def _inference_feature_columns(ensemble, batch):
    """Return the feature columns left in *batch* after inference-time dropping.

    Simulates the column dropping done in predict_proba (inference.py): the
    geometry column and all target columns are removed.
    NOTE(review): keep this in sync with the real inference code.
    """
    prefix = _target_prefix(ensemble)
    cols_to_drop = ["geometry", *(col for col in batch.columns if col.startswith(prefix))]
    return set(batch.drop(columns=cols_to_drop).columns)


def _feature_mismatch_message(training_cols, inference_cols):
    """Build a diagnostic message describing how two feature-column sets differ."""
    return (
        "Feature mismatch between training and inference!\n"
        f"Only in training: {training_cols - inference_cols}\n"
        f"Only in inference: {inference_cols - training_cols}\n"
        f"Training features ({len(training_cols)}): {sorted(training_cols)}\n"
        f"Inference features ({len(inference_cols)}): {sorted(inference_cols)}"
    )


def _all_batches_or_skip(ensemble, batch_size=100):
    """Materialize every batch of *ensemble*, skipping the test if there are none."""
    batches = list(ensemble.create_batches(batch_size=batch_size, cache_mode="n"))
    if not batches:
        pytest.skip(_EMPTY_BATCHES_REASON)
    return batches


def _first_batch_or_skip(ensemble, batch_size=100):
    """Return only the first batch of *ensemble* (cheap), skipping the test if there is none."""
    batch = next(ensemble.create_batches(batch_size=batch_size, cache_mode="n"), None)
    if batch is None:
        pytest.skip(_EMPTY_BATCHES_REASON)
    return batch


@pytest.fixture
def sample_ensemble():
    """Create a sample DatasetEnsemble for testing with minimal data."""
    return DatasetEnsemble(
        grid="hex",
        level=3,  # Use level 3 for much faster tests
        target="darts_rts",
        members=["AlphaEarth"],  # Use only one member for faster tests
        add_lonlat=True,
        filter_target="darts_has_coverage",  # Filter to reduce dataset size
    )


@pytest.fixture
def sample_ensemble_mllabels():
    """Create a sample DatasetEnsemble with mllabels target."""
    return DatasetEnsemble(
        grid="hex",
        level=3,  # Use level 3 for much faster tests
        target="darts_mllabels",
        members=["AlphaEarth"],  # Use only one member for faster tests
        add_lonlat=True,
        filter_target="dartsml_has_coverage",  # Filter to reduce dataset size
    )


class TestDatasetEnsemble:
    """Test suite for DatasetEnsemble class."""

    def test_initialization(self, sample_ensemble):
        """Test that DatasetEnsemble initializes correctly."""
        assert sample_ensemble.grid == "hex"
        assert sample_ensemble.level == 3
        assert sample_ensemble.target == "darts_rts"
        assert "AlphaEarth" in sample_ensemble.members
        assert sample_ensemble.add_lonlat is True

    def test_covcol_property(self, sample_ensemble, sample_ensemble_mllabels):
        """Test that covcol property returns correct column name."""
        assert sample_ensemble.covcol == "darts_has_coverage"
        assert sample_ensemble_mllabels.covcol == "dartsml_has_coverage"

    def test_taskcol_property(self, sample_ensemble, sample_ensemble_mllabels):
        """Test that taskcol returns correct column name for different tasks."""
        assert sample_ensemble.taskcol("binary") == "darts_has_rts"
        assert sample_ensemble.taskcol("count") == "darts_rts_count"
        assert sample_ensemble.taskcol("density") == "darts_rts_density"

        assert sample_ensemble_mllabels.taskcol("binary") == "dartsml_has_rts"
        assert sample_ensemble_mllabels.taskcol("count") == "dartsml_rts_count"
        assert sample_ensemble_mllabels.taskcol("density") == "dartsml_rts_density"

    def test_create_returns_geodataframe(self, sample_ensemble):
        """Test that create() returns a GeoDataFrame."""
        dataset = sample_ensemble.create(cache_mode="n")
        assert isinstance(dataset, gpd.GeoDataFrame)

    def test_create_has_expected_columns(self, sample_ensemble):
        """Test that create() returns dataset with expected columns."""
        dataset = sample_ensemble.create(cache_mode="n")
        columns = list(dataset.columns)

        # Should have geometry column
        assert "geometry" in columns

        # Should have lat/lon if add_lonlat is True
        if sample_ensemble.add_lonlat:
            assert "lon" in columns
            assert "lat" in columns

        # Should have target columns (darts_*)
        assert any(col.startswith("darts_") for col in columns)

        # Should have member data columns. A plain column count would also be
        # satisfied by geometry, lon/lat and target columns alone, so require at
        # least one column that is none of those.
        bookkeeping = {"geometry", "lon", "lat"}
        member_cols = [
            col for col in columns if col not in bookkeeping and not col.startswith("darts_")
        ]
        assert member_cols, f"No member data columns found; columns: {columns}"

    def test_create_batches_consistency(self, sample_ensemble):
        """Test that create_batches produces batches with consistent columns."""
        batches = _all_batches_or_skip(sample_ensemble, batch_size=100)

        # All batches should have the same columns
        first_batch_cols = set(batches[0].columns)
        for i, batch in enumerate(batches[1:], start=1):
            batch_cols = set(batch.columns)
            assert batch_cols == first_batch_cols, (
                f"Batch {i} has different columns than batch 0. "
                f"Difference: {batch_cols.symmetric_difference(first_batch_cols)}"
            )

    def test_create_vs_create_batches_columns(self, sample_ensemble):
        """Test that create() and create_batches() return datasets with the same columns."""
        full_dataset = sample_ensemble.create(cache_mode="n")
        batches = _all_batches_or_skip(sample_ensemble, batch_size=100)

        # Columns should be identical
        full_cols = set(full_dataset.columns)
        batch_cols = set(batches[0].columns)

        assert full_cols == batch_cols, (
            f"Column mismatch between create() and create_batches().\n"
            f"Only in create(): {full_cols - batch_cols}\n"
            f"Only in create_batches(): {batch_cols - full_cols}"
        )

    def test_training_dataset_feature_columns(self, sample_ensemble):
        """Test that create_cat_training_dataset creates proper feature columns."""
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")

        # Get the columns used for model inputs
        model_input_cols = set(training_data.X.data.columns)

        # These columns should NOT be in model inputs
        assert "geometry" not in model_input_cols
        assert sample_ensemble.covcol not in model_input_cols
        assert sample_ensemble.taskcol("binary") not in model_input_cols

        # No darts_* columns should be in model inputs; report every offender at once
        leaked = sorted(col for col in model_input_cols if col.startswith("darts_"))
        assert not leaked, f"Columns {leaked} should have been dropped from model inputs"

    def test_training_dataset_feature_columns_mllabels(self, sample_ensemble_mllabels):
        """Test feature columns for mllabels target."""
        training_data = sample_ensemble_mllabels.create_cat_training_dataset(
            task="binary", device="cpu"
        )
        model_input_cols = set(training_data.X.data.columns)

        # These columns should NOT be in model inputs
        assert "geometry" not in model_input_cols
        assert sample_ensemble_mllabels.covcol not in model_input_cols
        assert sample_ensemble_mllabels.taskcol("binary") not in model_input_cols

        # No dartsml_* columns should be in model inputs; report every offender at once
        leaked = sorted(col for col in model_input_cols if col.startswith("dartsml_"))
        assert not leaked, f"Columns {leaked} should have been dropped from model inputs"

    def test_inference_vs_training_feature_consistency(self, sample_ensemble):
        """Test that inference batches have the same features as training data after column dropping.

        This test simulates the workflow in training.py and inference.py to ensure
        feature consistency between training and inference.
        """
        # Step 1: Create training dataset (as in training.py)
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
        training_feature_cols = set(training_data.X.data.columns)

        # Step 2: Create inference batch (as in inference.py); only the first
        # batch is needed, which keeps the test fast.
        batch = _first_batch_or_skip(sample_ensemble, batch_size=100)
        inference_feature_cols = _inference_feature_columns(sample_ensemble, batch)

        # The features should match!
        assert training_feature_cols == inference_feature_cols, _feature_mismatch_message(
            training_feature_cols, inference_feature_cols
        )

    def test_inference_vs_training_feature_consistency_mllabels(self, sample_ensemble_mllabels):
        """Test feature consistency for mllabels target."""
        training_data = sample_ensemble_mllabels.create_cat_training_dataset(
            task="binary", device="cpu"
        )
        training_feature_cols = set(training_data.X.data.columns)

        # Only the first batch is needed, which keeps the test fast.
        batch = _first_batch_or_skip(sample_ensemble_mllabels, batch_size=100)
        inference_feature_cols = _inference_feature_columns(sample_ensemble_mllabels, batch)

        assert training_feature_cols == inference_feature_cols, _feature_mismatch_message(
            training_feature_cols, inference_feature_cols
        )

    def test_all_tasks_feature_consistency(self, sample_ensemble):
        """Test that all task types produce consistent features."""
        feature_sets = {
            task: set(
                sample_ensemble.create_cat_training_dataset(task=task, device="cpu").X.data.columns
            )
            for task in ("binary", "count", "density")
        }

        # All tasks should have the same features
        binary_features = feature_sets["binary"]
        for task, features in feature_sets.items():
            assert features == binary_features, (
                f"Task '{task}' has different features than 'binary'.\n"
                f"Difference: {features.symmetric_difference(binary_features)}"
            )

    def test_training_dataset_shapes(self, sample_ensemble):
        """Test that training dataset has correct shapes."""
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")

        n_features = len(training_data.X.data.columns)
        n_samples_train = training_data.X.train.shape[0]
        n_samples_test = training_data.X.test.shape[0]

        # Check X shapes
        assert training_data.X.train.shape == (n_samples_train, n_features)
        assert training_data.X.test.shape == (n_samples_test, n_features)

        # Check y shapes
        assert training_data.y.train.shape == (n_samples_train,)
        assert training_data.y.test.shape == (n_samples_test,)

        # Check that train + test = total samples
        assert len(training_data.dataset) == n_samples_train + n_samples_test

    def test_no_nan_in_training_features(self, sample_ensemble):
        """Test that training features don't contain NaN values."""
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")

        # device="cpu" keeps the arrays host-resident, which np.asarray relies on:
        # it refuses implicit conversion of GPU arrays such as CuPy's.
        X_train = np.asarray(training_data.X.train)
        X_test = np.asarray(training_data.X.test)

        assert not np.isnan(X_train).any(), "Training features contain NaN values"
        assert not np.isnan(X_test).any(), "Test features contain NaN values"

    def test_batch_coverage(self, sample_ensemble):
        """Test that batches cover all data without duplication."""
        full_dataset = sample_ensemble.create(cache_mode="n")
        batches = _all_batches_or_skip(sample_ensemble, batch_size=100)

        # Collect all cell_ids from batches
        batch_cell_ids = set()
        for i, batch in enumerate(batches):
            # Duplicates inside one batch would vanish in the set below, so check them first
            assert batch.index.is_unique, f"Batch {i} contains duplicate cell_ids"
            batch_ids = set(batch.index)

            # Check for duplicates across batches
            overlap = batch_cell_ids & batch_ids
            assert not overlap, f"Found {len(overlap)} duplicate cell_ids across batches"

            batch_cell_ids |= batch_ids

        # Check that all cell_ids from full dataset are in batches
        full_cell_ids = set(full_dataset.index)
        assert batch_cell_ids == full_cell_ids, (
            f"Batch coverage mismatch.\n"
            f"Missing from batches: {full_cell_ids - batch_cell_ids}\n"
            f"Extra in batches: {batch_cell_ids - full_cell_ids}"
        )


class TestDatasetEnsembleEdgeCases:
    """Test edge cases and error handling."""

    def test_invalid_task_raises_error(self, sample_ensemble):
        """Test that invalid task raises ValueError."""
        with pytest.raises(ValueError, match="Invalid task"):
            sample_ensemble.create_cat_training_dataset(task="invalid", device="cpu")  # type: ignore

    def test_stats_method(self, sample_ensemble):
        """Test that get_stats returns expected structure."""
        stats = sample_ensemble.get_stats()

        for key in ("target", "num_target_samples", "members", "total_features"):
            assert key in stats

        # Check that members dict contains info for each member
        for member in sample_ensemble.members:
            assert member in stats["members"]
            member_stats = stats["members"][member]
            assert "variables" in member_stats
            assert "num_features" in member_stats