311 lines
14 KiB
Python
311 lines
14 KiB
Python
|
|
"""Tests for dataset.py module, specifically DatasetEnsemble class."""
|
||
|
|
|
||
|
|
import geopandas as gpd
|
||
|
|
import numpy as np
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from entropice.ml.dataset import DatasetEnsemble
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def sample_ensemble():
    """Provide a minimal DatasetEnsemble instance for fast tests.

    Uses a hex grid at level 3 with a single ensemble member and a
    coverage filter so the underlying dataset stays small.
    """
    ensemble = DatasetEnsemble(
        grid="hex",
        level=3,  # low level keeps the grid tiny and the tests fast
        target="darts_rts",
        members=["AlphaEarth"],  # a single member keeps data loading cheap
        add_lonlat=True,
        filter_target="darts_has_coverage",  # restrict to covered cells only
    )
    return ensemble
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def sample_ensemble_mllabels():
    """Provide a minimal DatasetEnsemble with the mllabels target.

    Mirrors ``sample_ensemble`` but points at the ``darts_mllabels``
    target and its matching coverage filter.
    """
    ensemble = DatasetEnsemble(
        grid="hex",
        level=3,  # low level keeps the grid tiny and the tests fast
        target="darts_mllabels",
        members=["AlphaEarth"],  # a single member keeps data loading cheap
        add_lonlat=True,
        filter_target="dartsml_has_coverage",  # restrict to covered cells only
    )
    return ensemble
|
||
|
|
|
||
|
|
|
||
|
|
class TestDatasetEnsemble:
    """Test suite for DatasetEnsemble class."""

    @staticmethod
    def _drop_label_columns(ensemble, batch):
        """Simulate the column dropping performed in predict_proba.

        Drops the geometry column plus every target-prefixed column:
        ``dartsml_*`` when the ensemble targets ``darts_mllabels``,
        otherwise ``darts_*`` (which also matches ``dartsml_*`` columns,
        same as the inference-time logic this mirrors).

        Args:
            ensemble: The DatasetEnsemble whose target decides the prefix.
            batch: A GeoDataFrame batch as produced by ``create_batches``.

        Returns:
            The batch with geometry and label columns removed.
        """
        prefix = "dartsml_" if ensemble.target == "darts_mllabels" else "darts_"
        cols_to_drop = ["geometry"] + [col for col in batch.columns if col.startswith(prefix)]
        return batch.drop(columns=cols_to_drop)

    def test_initialization(self, sample_ensemble):
        """Test that DatasetEnsemble initializes correctly."""
        assert sample_ensemble.grid == "hex"
        assert sample_ensemble.level == 3
        assert sample_ensemble.target == "darts_rts"
        assert "AlphaEarth" in sample_ensemble.members
        assert sample_ensemble.add_lonlat is True

    def test_covcol_property(self, sample_ensemble, sample_ensemble_mllabels):
        """Test that covcol property returns correct column name."""
        assert sample_ensemble.covcol == "darts_has_coverage"
        assert sample_ensemble_mllabels.covcol == "dartsml_has_coverage"

    def test_taskcol_property(self, sample_ensemble, sample_ensemble_mllabels):
        """Test that taskcol returns correct column name for different tasks."""
        assert sample_ensemble.taskcol("binary") == "darts_has_rts"
        assert sample_ensemble.taskcol("count") == "darts_rts_count"
        assert sample_ensemble.taskcol("density") == "darts_rts_density"

        assert sample_ensemble_mllabels.taskcol("binary") == "dartsml_has_rts"
        assert sample_ensemble_mllabels.taskcol("count") == "dartsml_rts_count"
        assert sample_ensemble_mllabels.taskcol("density") == "dartsml_rts_density"

    def test_create_returns_geodataframe(self, sample_ensemble):
        """Test that create() returns a GeoDataFrame."""
        dataset = sample_ensemble.create(cache_mode="n")
        assert isinstance(dataset, gpd.GeoDataFrame)

    def test_create_has_expected_columns(self, sample_ensemble):
        """Test that create() returns dataset with expected columns."""
        dataset = sample_ensemble.create(cache_mode="n")

        # Should have geometry column
        assert "geometry" in dataset.columns

        # Should have lat/lon if add_lonlat is True
        if sample_ensemble.add_lonlat:
            assert "lon" in dataset.columns
            assert "lat" in dataset.columns

        # Should have target columns (darts_*)
        assert any(col.startswith("darts_") for col in dataset.columns)

        # Should have member data columns
        assert len(dataset.columns) > 3  # More than just geometry, lat, lon

    def test_create_batches_consistency(self, sample_ensemble):
        """Test that create_batches produces batches with consistent columns."""
        batch_size = 100
        batches = list(sample_ensemble.create_batches(batch_size=batch_size, cache_mode="n"))

        if len(batches) == 0:
            pytest.skip("No batches created (dataset might be empty)")

        # All batches should have the same columns
        first_batch_cols = set(batches[0].columns)
        for i, batch in enumerate(batches[1:], start=1):
            assert set(batch.columns) == first_batch_cols, (
                f"Batch {i} has different columns than batch 0. "
                f"Difference: {set(batch.columns).symmetric_difference(first_batch_cols)}"
            )

    def test_create_vs_create_batches_columns(self, sample_ensemble):
        """Test that create() and create_batches() return datasets with the same columns."""
        full_dataset = sample_ensemble.create(cache_mode="n")
        batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))

        if len(batches) == 0:
            pytest.skip("No batches created (dataset might be empty)")

        # Columns should be identical
        full_cols = set(full_dataset.columns)
        batch_cols = set(batches[0].columns)

        assert full_cols == batch_cols, (
            f"Column mismatch between create() and create_batches().\n"
            f"Only in create(): {full_cols - batch_cols}\n"
            f"Only in create_batches(): {batch_cols - full_cols}"
        )

    def test_training_dataset_feature_columns(self, sample_ensemble):
        """Test that create_cat_training_dataset creates proper feature columns."""
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")

        # Get the columns used for model inputs
        model_input_cols = set(training_data.X.data.columns)

        # These columns should NOT be in model inputs
        assert "geometry" not in model_input_cols
        assert sample_ensemble.covcol not in model_input_cols
        assert sample_ensemble.taskcol("binary") not in model_input_cols

        # No darts_* columns should be in model inputs
        for col in model_input_cols:
            assert not col.startswith("darts_"), f"Column {col} should have been dropped from model inputs"

    def test_training_dataset_feature_columns_mllabels(self, sample_ensemble_mllabels):
        """Test feature columns for mllabels target."""
        training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")

        model_input_cols = set(training_data.X.data.columns)

        # These columns should NOT be in model inputs
        assert "geometry" not in model_input_cols
        assert sample_ensemble_mllabels.covcol not in model_input_cols
        assert sample_ensemble_mllabels.taskcol("binary") not in model_input_cols

        # No dartsml_* columns should be in model inputs
        for col in model_input_cols:
            assert not col.startswith("dartsml_"), f"Column {col} should have been dropped from model inputs"

    def test_inference_vs_training_feature_consistency(self, sample_ensemble):
        """Test that inference batches have the same features as training data after column dropping.

        This test simulates the workflow in training.py and inference.py to ensure
        feature consistency between training and inference.
        """
        # Step 1: Create training dataset (as in training.py)
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
        training_feature_cols = set(training_data.X.data.columns)

        # Step 2: Create inference batch (as in inference.py)
        # Get just the first batch to speed up test
        batch_generator = sample_ensemble.create_batches(batch_size=100, cache_mode="n")
        batch = next(batch_generator, None)

        if batch is None:
            pytest.skip("No batches created (dataset might be empty)")

        # Simulate the column dropping in predict_proba
        inference_batch = self._drop_label_columns(sample_ensemble, batch)
        inference_feature_cols = set(inference_batch.columns)

        # The features should match!
        assert training_feature_cols == inference_feature_cols, (
            f"Feature mismatch between training and inference!\n"
            f"Only in training: {training_feature_cols - inference_feature_cols}\n"
            f"Only in inference: {inference_feature_cols - training_feature_cols}\n"
            f"Training features ({len(training_feature_cols)}): {sorted(training_feature_cols)}\n"
            f"Inference features ({len(inference_feature_cols)}): {sorted(inference_feature_cols)}"
        )

    def test_inference_vs_training_feature_consistency_mllabels(self, sample_ensemble_mllabels):
        """Test feature consistency for mllabels target."""
        training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")
        training_feature_cols = set(training_data.X.data.columns)

        # Get just the first batch to speed up test
        batch_generator = sample_ensemble_mllabels.create_batches(batch_size=100, cache_mode="n")
        batch = next(batch_generator, None)

        if batch is None:
            pytest.skip("No batches created (dataset might be empty)")

        # Simulate the column dropping in predict_proba
        inference_batch = self._drop_label_columns(sample_ensemble_mllabels, batch)
        inference_feature_cols = set(inference_batch.columns)

        assert training_feature_cols == inference_feature_cols, (
            f"Feature mismatch between training and inference!\n"
            f"Only in training: {training_feature_cols - inference_feature_cols}\n"
            f"Only in inference: {inference_feature_cols - training_feature_cols}"
        )

    def test_all_tasks_feature_consistency(self, sample_ensemble):
        """Test that all task types produce consistent features."""
        tasks = ["binary", "count", "density"]
        feature_sets = {}

        for task in tasks:
            training_data = sample_ensemble.create_cat_training_dataset(task=task, device="cpu")
            feature_sets[task] = set(training_data.X.data.columns)

        # All tasks should have the same features as the binary baseline
        binary_features = feature_sets["binary"]
        for task, features in feature_sets.items():
            if task == "binary":
                continue  # trivially equal to itself
            assert features == binary_features, (
                f"Task '{task}' has different features than 'binary'.\n"
                f"Difference: {features.symmetric_difference(binary_features)}"
            )

    def test_training_dataset_shapes(self, sample_ensemble):
        """Test that training dataset has correct shapes."""
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")

        n_features = len(training_data.X.data.columns)
        n_samples_train = training_data.X.train.shape[0]
        n_samples_test = training_data.X.test.shape[0]

        # Check X shapes
        assert training_data.X.train.shape == (n_samples_train, n_features)
        assert training_data.X.test.shape == (n_samples_test, n_features)

        # Check y shapes
        assert training_data.y.train.shape == (n_samples_train,)
        assert training_data.y.test.shape == (n_samples_test,)

        # Check that train + test = total samples
        assert len(training_data.dataset) == n_samples_train + n_samples_test

    def test_no_nan_in_training_features(self, sample_ensemble):
        """Test that training features don't contain NaN values."""
        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")

        # Convert to numpy for checking (handles both numpy and cupy arrays)
        X_train = np.asarray(training_data.X.train)
        X_test = np.asarray(training_data.X.test)

        assert not np.isnan(X_train).any(), "Training features contain NaN values"
        assert not np.isnan(X_test).any(), "Test features contain NaN values"

    def test_batch_coverage(self, sample_ensemble):
        """Test that batches cover all data without duplication."""
        full_dataset = sample_ensemble.create(cache_mode="n")
        batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))

        if len(batches) == 0:
            pytest.skip("No batches created (dataset might be empty)")

        # Collect all cell_ids from batches
        batch_cell_ids = set()
        for batch in batches:
            batch_ids = set(batch.index)
            # Check for duplicates across batches
            overlap = batch_cell_ids.intersection(batch_ids)
            assert len(overlap) == 0, f"Found {len(overlap)} duplicate cell_ids across batches"
            batch_cell_ids.update(batch_ids)

        # Check that all cell_ids from full dataset are in batches
        full_cell_ids = set(full_dataset.index)
        assert batch_cell_ids == full_cell_ids, (
            f"Batch coverage mismatch.\n"
            f"Missing from batches: {full_cell_ids - batch_cell_ids}\n"
            f"Extra in batches: {batch_cell_ids - full_cell_ids}"
        )
|
||
|
|
|
||
|
|
|
||
|
|
class TestDatasetEnsembleEdgeCases:
    """Test edge cases and error handling."""

    def test_invalid_task_raises_error(self, sample_ensemble):
        """An unknown task name must raise ValueError."""
        with pytest.raises(ValueError, match="Invalid task"):
            sample_ensemble.create_cat_training_dataset(task="invalid", device="cpu")  # type: ignore

    def test_stats_method(self, sample_ensemble):
        """get_stats() must expose the expected top-level and per-member keys."""
        stats = sample_ensemble.get_stats()

        # Top-level summary keys
        for key in ("target", "num_target_samples", "members", "total_features"):
            assert key in stats

        # Every configured member gets its own entry with variable info
        member_stats = stats["members"]
        for name in sample_ensemble.members:
            assert name in member_stats
            assert "variables" in member_stats[name]
            assert "num_features" in member_stats[name]
|