# entropice/tests/test_dataset.py

"""Tests for dataset.py module, specifically DatasetEnsemble class."""
2026-01-19 17:05:44 +01:00
from collections.abc import Generator
import geopandas as gpd
import numpy as np
2026-01-19 17:05:44 +01:00
import pandas as pd
import pytest
2026-01-19 17:05:44 +01:00
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
from entropice.utils.types import all_target_datasets, all_tasks
@pytest.fixture
def sample_ensemble() -> Generator[DatasetEnsemble]:
    """Create a sample DatasetEnsemble for testing with minimal data."""
    yield DatasetEnsemble(
        grid="hex",
        level=3,  # Use level 3 for much faster tests
        members=["AlphaEarth"],  # Use only one member for faster tests
    )
@pytest.fixture
def sample_ensemble_v2() -> Generator[DatasetEnsemble]:
    """Create a sample DatasetEnsemble for testing with v2 target.

    NOTE(review): this fixture is currently identical to ``sample_ensemble``
    and does not reference any v2 target — presumably the target is selected
    later via ``get_targets``/``create_training_set``; confirm, or remove the
    duplicate fixture.
    """
    yield DatasetEnsemble(
        grid="hex",
        level=3,  # Use level 3 for much faster tests
        members=["AlphaEarth"],  # Use only one member for faster tests
    )
class TestDatasetEnsemble:
    """Test suite for DatasetEnsemble class."""

    def test_initialization(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that DatasetEnsemble initializes correctly."""
        assert sample_ensemble.grid == "hex"
        assert sample_ensemble.level == 3
        assert "AlphaEarth" in sample_ensemble.members

    def test_get_targets_returns_geodataframe(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that get_targets() returns a GeoDataFrame."""
        targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
        assert isinstance(targets, gpd.GeoDataFrame)

    def test_get_targets_has_expected_columns(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that get_targets() returns expected columns."""
        targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")

        # Should have required columns
        assert "geometry" in targets.columns
        assert "y" in targets.columns  # target label
        assert "z" in targets.columns  # raw value
        assert targets.index.name == "cell_id"

    def test_get_targets_binary_task(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that binary task creates proper labels."""
        targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")

        # Binary task should have categorical y with "No RTS" and "RTS"
        assert targets["y"].dtype.name == "category"
        assert set(targets["y"].cat.categories) == {"No RTS", "RTS"}

        # z should be numeric (count)
        assert pd.api.types.is_numeric_dtype(targets["z"])

    def test_get_targets_count_task(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that count task creates proper labels."""
        targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="count")

        # Count task should have numeric y and z that are identical
        assert pd.api.types.is_numeric_dtype(targets["y"])
        assert pd.api.types.is_numeric_dtype(targets["z"])
        assert (targets["y"] == targets["z"]).all()

    def test_get_targets_density_task(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that density task creates proper labels."""
        targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="density")

        # Density task should have numeric y and z that are identical
        assert pd.api.types.is_numeric_dtype(targets["y"])
        assert pd.api.types.is_numeric_dtype(targets["z"])
        assert (targets["y"] == targets["z"]).all()

    def test_make_features_returns_dataframe(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that make_features() returns a DataFrame."""
        # Get a small subset of cell_ids for faster testing
        targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
        cell_ids: pd.Series = targets.index.to_series().head(50)

        features: pd.DataFrame = sample_ensemble.make_features(cell_ids=cell_ids, cache_mode="none")
        assert isinstance(features, pd.DataFrame)

    def test_make_features_has_expected_columns(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that make_features() returns expected feature columns."""
        targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
        cell_ids: pd.Series = targets.index.to_series().head(50)
        features: pd.DataFrame = sample_ensemble.make_features(cell_ids=cell_ids, cache_mode="none")

        # Should NOT have geometry column
        assert "geometry" not in features.columns

        # Should have location columns
        assert "x" in features.columns
        assert "y" in features.columns

        # Should have grid property columns
        assert "cell_area" in features.columns
        assert "land_area" in features.columns

        # Should have member feature columns
        assert any(col.startswith("embeddings_") for col in features.columns)

    def test_inference_df_batches_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that create_inference_df produces batches with consistent columns."""
        batch_size: int = 100
        batches: list[pd.DataFrame] = list(
            sample_ensemble.create_inference_df(batch_size=batch_size, cache_mode="none")
        )
        if len(batches) == 0:
            pytest.skip("No batches created (dataset might be empty)")

        # All batches should have the same columns
        first_batch_cols = set(batches[0].columns)
        for i, batch in enumerate(batches[1:], start=1):
            assert set(batch.columns) == first_batch_cols, (
                f"Batch {i} has different columns than batch 0. "
                f"Difference: {set(batch.columns).symmetric_difference(first_batch_cols)}"
            )

    def test_training_set_structure(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that create_training_set creates proper structure."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        # Check that TrainingSet has all expected attributes
        assert hasattr(training_set, "targets")
        assert hasattr(training_set, "features")
        assert hasattr(training_set, "X")
        assert hasattr(training_set, "y")
        assert hasattr(training_set, "z")
        assert hasattr(training_set, "split")

    def test_training_set_feature_columns(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that create_training_set creates proper feature columns."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        # Get the feature column names
        feature_cols: set[str] = set(training_set.features.columns)

        # These columns should NOT be in features
        assert "geometry" not in feature_cols
        # Note: "y" and "x" are spatial coordinates from the grid and are valid features
        # The target labels are stored in training_set.targets["y"], not in features

    def test_training_set_no_target_leakage(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that target information doesn't leak into features."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        # Features should not contain target-related columns
        feature_cols: pd.Index = training_set.features.columns
        assert all(not col.startswith("darts_") for col in feature_cols)

    def test_inference_vs_training_feature_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that inference batches have the same features as training data.

        This test simulates the workflow in training.py and inference.py to ensure
        feature consistency between training and inference.
        """
        # Step 1: Create training set (as in training.py)
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )
        training_feature_cols: set[str] = set(training_set.features.columns)

        # Step 2: Create inference batch (as in inference.py)
        batch_generator: Generator[pd.DataFrame] = sample_ensemble.create_inference_df(
            batch_size=100, cache_mode="none"
        )
        batch: pd.DataFrame | None = next(batch_generator, None)
        if batch is None:
            pytest.skip("No batches created (dataset might be empty)")

        assert batch is not None  # For type checker
        inference_feature_cols: set[str] = set(batch.columns)

        # The features should match!
        assert training_feature_cols == inference_feature_cols, (
            f"Feature mismatch between training and inference!\n"
            f"Only in training: {training_feature_cols - inference_feature_cols}\n"
            f"Only in inference: {inference_feature_cols - training_feature_cols}\n"
            f"Training features ({len(training_feature_cols)}): {sorted(training_feature_cols)}\n"
            f"Inference features ({len(inference_feature_cols)}): {sorted(inference_feature_cols)}"
        )

    def test_all_tasks_feature_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that all task types produce consistent features."""
        feature_sets: dict[str, set[str]] = {}
        for task in all_tasks:
            training_set: TrainingSet = sample_ensemble.create_training_set(
                task=task,
                target="darts_v1",
                device="cpu",
                cache_mode="none",
            )
            feature_sets[task] = set(training_set.features.columns)

        # All tasks should have the same features
        binary_features: set[str] = feature_sets["binary"]
        for task, features in feature_sets.items():
            assert features == binary_features, (
                f"Task '{task}' has different features than 'binary'.\n"
                f"Difference: {features.symmetric_difference(binary_features)}"
            )

    def test_training_set_shapes(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that training set has correct shapes."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        n_features: int = len(training_set.features.columns)
        n_samples_train: int = training_set.X.train.shape[0]
        n_samples_test: int = training_set.X.test.shape[0]

        # Check X shapes
        assert training_set.X.train.shape == (n_samples_train, n_features)
        assert training_set.X.test.shape == (n_samples_test, n_features)

        # Check y shapes
        assert training_set.y.train.shape == (n_samples_train,)
        assert training_set.y.test.shape == (n_samples_test,)

        # Check that train + test = total samples
        assert len(training_set) == n_samples_train + n_samples_test

    def test_no_nan_in_training_features(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that training features don't contain NaN values."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        # Convert to numpy for checking (handles both numpy and cupy arrays)
        X_train: np.ndarray = np.asarray(training_set.X.train)  # noqa: N806
        X_test: np.ndarray = np.asarray(training_set.X.test)  # noqa: N806
        assert not np.isnan(X_train).any(), "Training features contain NaN values"
        assert not np.isnan(X_test).any(), "Test features contain NaN values"

    def test_batch_coverage(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that batches cover all data without duplication."""
        batches: list[pd.DataFrame] = list(sample_ensemble.create_inference_df(batch_size=100, cache_mode="none"))
        if len(batches) == 0:
            pytest.skip("No batches created (dataset might be empty)")

        # Collect all cell_ids from batches
        batch_cell_ids = set()
        for batch in batches:
            batch_ids = set(batch.index)
            # Check for duplicates across batches
            overlap = batch_cell_ids.intersection(batch_ids)
            assert len(overlap) == 0, f"Found {len(overlap)} duplicate cell_ids across batches"
            batch_cell_ids.update(batch_ids)

        # Check that all cell_ids from grid are in batches
        all_cell_ids = set(sample_ensemble.cell_ids.to_numpy())
        assert batch_cell_ids == all_cell_ids, (
            f"Batch coverage mismatch.\n"
            f"Missing from batches: {all_cell_ids - batch_cell_ids}\n"
            f"Extra in batches: {batch_cell_ids - all_cell_ids}"
        )
class TestDatasetEnsembleEdgeCases:
    """Test edge cases and error handling."""

    def test_invalid_task_raises_error(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that an invalid task raises NotImplementedError."""
        # Docstring previously said ValueError, but the implementation raises
        # NotImplementedError, which is what this asserts.
        with pytest.raises(NotImplementedError, match=r"Task .* not supported"):
            sample_ensemble.create_training_set(task="invalid", target="darts_v1", device="cpu")  # type: ignore[arg-type]

    def test_target_labels_property(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that TrainingSet provides target labels for categorical tasks."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        # Binary task should have labels
        labels: list[str] | None = training_set.target_labels
        assert labels is not None
        assert isinstance(labels, list)
        assert "No RTS" in labels
        assert "RTS" in labels

    def test_target_labels_none_for_regression(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that regression tasks return None for target_labels."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="count", target="darts_v1", device="cpu", cache_mode="none"
        )

        # Regression task should return None
        assert training_set.target_labels is None

    def test_to_dataframe_method(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that to_dataframe returns expected structure."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        # Get full dataset
        df_full: pd.DataFrame = training_set.to_dataframe(split=None)
        assert isinstance(df_full, pd.DataFrame)
        assert "label" in df_full.columns
        assert len(df_full) == len(training_set)

        # Get train split
        df_train = training_set.to_dataframe(split="train")
        assert len(df_train) < len(df_full)
        assert all(training_set.split[df_train.index] == "train")

        # Get test split
        df_test = training_set.to_dataframe(split="test")
        assert len(df_test) < len(df_full)
        assert all(training_set.split[df_test.index] == "test")

        # Train + test should equal full
        assert len(df_train) + len(df_test) == len(df_full)

    def test_different_targets(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that different target datasets can be used."""
        for target in all_target_datasets:
            try:
                training_set: TrainingSet = sample_ensemble.create_training_set(
                    task="binary",
                    target=target,
                    device="cpu",
                    cache_mode="none",
                )
                assert len(training_set) > 0
            except Exception as e:  # noqa: BLE001 — best-effort: skip unavailable targets
                pytest.skip(f"Target {target} not available: {e}")

    def test_feature_names_property(self, sample_ensemble: DatasetEnsemble) -> None:
        """Test that feature_names property returns list of feature names."""
        training_set: TrainingSet = sample_ensemble.create_training_set(
            task="binary", target="darts_v1", device="cpu", cache_mode="none"
        )

        feature_names: list[str] = training_set.feature_names
        assert isinstance(feature_names, list)
        assert len(feature_names) == training_set.features.shape[1]
        assert all(isinstance(name, str) for name in feature_names)