Fix dataset test
This commit is contained in:
parent
87a0d03af1
commit
073502c51d
1 changed files with 234 additions and 176 deletions
|
|
@ -1,91 +1,129 @@
|
||||||
"""Tests for dataset.py module, specifically DatasetEnsemble class."""
|
"""Tests for dataset.py module, specifically DatasetEnsemble class."""
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from entropice.ml.dataset import DatasetEnsemble
|
from entropice.ml.dataset import DatasetEnsemble, TrainingSet
|
||||||
|
from entropice.utils.types import all_target_datasets, all_tasks
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_ensemble():
|
def sample_ensemble() -> Generator[DatasetEnsemble]:
|
||||||
"""Create a sample DatasetEnsemble for testing with minimal data."""
|
"""Create a sample DatasetEnsemble for testing with minimal data."""
|
||||||
return DatasetEnsemble(
|
yield DatasetEnsemble(
|
||||||
grid="hex",
|
grid="hex",
|
||||||
level=3, # Use level 3 for much faster tests
|
level=3, # Use level 3 for much faster tests
|
||||||
target="darts_rts",
|
|
||||||
members=["AlphaEarth"], # Use only one member for faster tests
|
members=["AlphaEarth"], # Use only one member for faster tests
|
||||||
add_lonlat=True,
|
add_lonlat=True,
|
||||||
filter_target="darts_has_coverage", # Filter to reduce dataset size
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_ensemble_mllabels():
|
def sample_ensemble_v2() -> Generator[DatasetEnsemble]:
|
||||||
"""Create a sample DatasetEnsemble with mllabels target."""
|
"""Create a sample DatasetEnsemble for testing with v2 target."""
|
||||||
return DatasetEnsemble(
|
yield DatasetEnsemble(
|
||||||
grid="hex",
|
grid="hex",
|
||||||
level=3, # Use level 3 for much faster tests
|
level=3, # Use level 3 for much faster tests
|
||||||
target="darts_mllabels",
|
|
||||||
members=["AlphaEarth"], # Use only one member for faster tests
|
members=["AlphaEarth"], # Use only one member for faster tests
|
||||||
add_lonlat=True,
|
add_lonlat=True,
|
||||||
filter_target="dartsml_has_coverage", # Filter to reduce dataset size
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestDatasetEnsemble:
|
class TestDatasetEnsemble:
|
||||||
"""Test suite for DatasetEnsemble class."""
|
"""Test suite for DatasetEnsemble class."""
|
||||||
|
|
||||||
def test_initialization(self, sample_ensemble):
|
def test_initialization(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
"""Test that DatasetEnsemble initializes correctly."""
|
"""Test that DatasetEnsemble initializes correctly."""
|
||||||
assert sample_ensemble.grid == "hex"
|
assert sample_ensemble.grid == "hex"
|
||||||
assert sample_ensemble.level == 3
|
assert sample_ensemble.level == 3
|
||||||
assert sample_ensemble.target == "darts_rts"
|
|
||||||
assert "AlphaEarth" in sample_ensemble.members
|
assert "AlphaEarth" in sample_ensemble.members
|
||||||
assert sample_ensemble.add_lonlat is True
|
assert sample_ensemble.add_lonlat is True
|
||||||
|
|
||||||
def test_covcol_property(self, sample_ensemble, sample_ensemble_mllabels):
|
def test_get_targets_returns_geodataframe(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
"""Test that covcol property returns correct column name."""
|
"""Test that get_targets() returns a GeoDataFrame."""
|
||||||
assert sample_ensemble.covcol == "darts_has_coverage"
|
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
|
||||||
assert sample_ensemble_mllabels.covcol == "dartsml_has_coverage"
|
assert isinstance(targets, gpd.GeoDataFrame)
|
||||||
|
|
||||||
def test_taskcol_property(self, sample_ensemble, sample_ensemble_mllabels):
|
def test_get_targets_has_expected_columns(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
"""Test that taskcol returns correct column name for different tasks."""
|
"""Test that get_targets() returns expected columns."""
|
||||||
assert sample_ensemble.taskcol("binary") == "darts_has_rts"
|
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
|
||||||
assert sample_ensemble.taskcol("count") == "darts_rts_count"
|
|
||||||
assert sample_ensemble.taskcol("density") == "darts_rts_density"
|
|
||||||
|
|
||||||
assert sample_ensemble_mllabels.taskcol("binary") == "dartsml_has_rts"
|
# Should have required columns
|
||||||
assert sample_ensemble_mllabels.taskcol("count") == "dartsml_rts_count"
|
assert "geometry" in targets.columns
|
||||||
assert sample_ensemble_mllabels.taskcol("density") == "dartsml_rts_density"
|
assert "y" in targets.columns # target label
|
||||||
|
assert "z" in targets.columns # raw value
|
||||||
|
assert targets.index.name == "cell_id"
|
||||||
|
|
||||||
def test_create_returns_geodataframe(self, sample_ensemble):
|
def test_get_targets_binary_task(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
"""Test that create() returns a GeoDataFrame."""
|
"""Test that binary task creates proper labels."""
|
||||||
dataset = sample_ensemble.create(cache_mode="n")
|
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
|
||||||
assert isinstance(dataset, gpd.GeoDataFrame)
|
|
||||||
|
|
||||||
def test_create_has_expected_columns(self, sample_ensemble):
|
# Binary task should have categorical y with "No RTS" and "RTS"
|
||||||
"""Test that create() returns dataset with expected columns."""
|
assert targets["y"].dtype.name == "category"
|
||||||
dataset = sample_ensemble.create(cache_mode="n")
|
assert set(targets["y"].cat.categories) == {"No RTS", "RTS"}
|
||||||
|
|
||||||
# Should have geometry column
|
# z should be numeric (count)
|
||||||
assert "geometry" in dataset.columns
|
assert pd.api.types.is_numeric_dtype(targets["z"])
|
||||||
|
|
||||||
# Should have lat/lon if add_lonlat is True
|
def test_get_targets_count_task(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
|
"""Test that count task creates proper labels."""
|
||||||
|
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="count")
|
||||||
|
|
||||||
|
# Count task should have numeric y and z that are identical
|
||||||
|
assert pd.api.types.is_numeric_dtype(targets["y"])
|
||||||
|
assert pd.api.types.is_numeric_dtype(targets["z"])
|
||||||
|
assert (targets["y"] == targets["z"]).all()
|
||||||
|
|
||||||
|
def test_get_targets_density_task(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
|
"""Test that density task creates proper labels."""
|
||||||
|
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="density")
|
||||||
|
|
||||||
|
# Density task should have numeric y and z that are identical
|
||||||
|
assert pd.api.types.is_numeric_dtype(targets["y"])
|
||||||
|
assert pd.api.types.is_numeric_dtype(targets["z"])
|
||||||
|
assert (targets["y"] == targets["z"]).all()
|
||||||
|
|
||||||
|
def test_make_features_returns_dataframe(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
|
"""Test that make_features() returns a DataFrame."""
|
||||||
|
# Get a small subset of cell_ids for faster testing
|
||||||
|
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
|
||||||
|
cell_ids: pd.Series = targets.index.to_series().head(50)
|
||||||
|
|
||||||
|
features: pd.DataFrame = sample_ensemble.make_features(cell_ids=cell_ids, cache_mode="none")
|
||||||
|
assert isinstance(features, pd.DataFrame)
|
||||||
|
|
||||||
|
def test_make_features_has_expected_columns(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
|
"""Test that make_features() returns expected feature columns."""
|
||||||
|
targets: gpd.GeoDataFrame = sample_ensemble.get_targets(target="darts_v1", task="binary")
|
||||||
|
cell_ids: pd.Series = targets.index.to_series().head(50)
|
||||||
|
|
||||||
|
features: pd.DataFrame = sample_ensemble.make_features(cell_ids=cell_ids, cache_mode="none")
|
||||||
|
|
||||||
|
# Should NOT have geometry column
|
||||||
|
assert "geometry" not in features.columns
|
||||||
|
|
||||||
|
# Should have location columns if add_lonlat is True
|
||||||
if sample_ensemble.add_lonlat:
|
if sample_ensemble.add_lonlat:
|
||||||
assert "lon" in dataset.columns
|
assert "x" in features.columns
|
||||||
assert "lat" in dataset.columns
|
assert "y" in features.columns
|
||||||
|
|
||||||
# Should have target columns (darts_*)
|
# Should have grid property columns
|
||||||
assert any(col.startswith("darts_") for col in dataset.columns)
|
assert "cell_area" in features.columns
|
||||||
|
assert "land_area" in features.columns
|
||||||
|
|
||||||
# Should have member data columns
|
# Should have member feature columns
|
||||||
assert len(dataset.columns) > 3 # More than just geometry, lat, lon
|
assert any(col.startswith("embeddings_") for col in features.columns)
|
||||||
|
|
||||||
def test_create_batches_consistency(self, sample_ensemble):
|
def test_inference_df_batches_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
"""Test that create_batches produces batches with consistent columns."""
|
"""Test that create_inference_df produces batches with consistent columns."""
|
||||||
batch_size = 100
|
batch_size: int = 100
|
||||||
batches = list(sample_ensemble.create_batches(batch_size=batch_size, cache_mode="n"))
|
batches: list[pd.DataFrame] = list(
|
||||||
|
sample_ensemble.create_inference_df(batch_size=batch_size, cache_mode="none")
|
||||||
|
)
|
||||||
|
|
||||||
if len(batches) == 0:
|
if len(batches) == 0:
|
||||||
pytest.skip("No batches created (dataset might be empty)")
|
pytest.skip("No batches created (dataset might be empty)")
|
||||||
|
|
@ -98,83 +136,67 @@ class TestDatasetEnsemble:
|
||||||
f"Difference: {set(batch.columns).symmetric_difference(first_batch_cols)}"
|
f"Difference: {set(batch.columns).symmetric_difference(first_batch_cols)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_create_vs_create_batches_columns(self, sample_ensemble):
|
def test_training_set_structure(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
"""Test that create() and create_batches() return datasets with the same columns."""
|
"""Test that create_training_set creates proper structure."""
|
||||||
full_dataset = sample_ensemble.create(cache_mode="n")
|
training_set: TrainingSet = sample_ensemble.create_training_set(
|
||||||
batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))
|
task="binary", target="darts_v1", device="cpu", cache_mode="none"
|
||||||
|
|
||||||
if len(batches) == 0:
|
|
||||||
pytest.skip("No batches created (dataset might be empty)")
|
|
||||||
|
|
||||||
# Columns should be identical
|
|
||||||
full_cols = set(full_dataset.columns)
|
|
||||||
batch_cols = set(batches[0].columns)
|
|
||||||
|
|
||||||
assert full_cols == batch_cols, (
|
|
||||||
f"Column mismatch between create() and create_batches().\n"
|
|
||||||
f"Only in create(): {full_cols - batch_cols}\n"
|
|
||||||
f"Only in create_batches(): {batch_cols - full_cols}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_training_dataset_feature_columns(self, sample_ensemble):
|
# Check that TrainingSet has all expected attributes
|
||||||
"""Test that create_cat_training_dataset creates proper feature columns."""
|
assert hasattr(training_set, "targets")
|
||||||
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
|
assert hasattr(training_set, "features")
|
||||||
|
assert hasattr(training_set, "X")
|
||||||
|
assert hasattr(training_set, "y")
|
||||||
|
assert hasattr(training_set, "z")
|
||||||
|
assert hasattr(training_set, "split")
|
||||||
|
|
||||||
# Get the columns used for model inputs
|
def test_training_set_feature_columns(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
model_input_cols = set(training_data.X.data.columns)
|
"""Test that create_training_set creates proper feature columns."""
|
||||||
|
training_set: TrainingSet = sample_ensemble.create_training_set(
|
||||||
|
task="binary", target="darts_v1", device="cpu", cache_mode="none"
|
||||||
|
)
|
||||||
|
|
||||||
# These columns should NOT be in model inputs
|
# Get the feature column names
|
||||||
assert "geometry" not in model_input_cols
|
feature_cols: set[str] = set(training_set.features.columns)
|
||||||
assert sample_ensemble.covcol not in model_input_cols
|
|
||||||
assert sample_ensemble.taskcol("binary") not in model_input_cols
|
|
||||||
|
|
||||||
# No darts_* columns should be in model inputs
|
# These columns should NOT be in features
|
||||||
for col in model_input_cols:
|
assert "geometry" not in feature_cols
|
||||||
assert not col.startswith("darts_"), f"Column {col} should have been dropped from model inputs"
|
# Note: "y" and "x" are spatial coordinates from the grid and are valid features
|
||||||
|
# The target labels are stored in training_set.targets["y"], not in features
|
||||||
|
|
||||||
def test_training_dataset_feature_columns_mllabels(self, sample_ensemble_mllabels):
|
def test_training_set_no_target_leakage(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
"""Test feature columns for mllabels target."""
|
"""Test that target information doesn't leak into features."""
|
||||||
training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")
|
training_set: TrainingSet = sample_ensemble.create_training_set(
|
||||||
|
task="binary", target="darts_v1", device="cpu", cache_mode="none"
|
||||||
|
)
|
||||||
|
|
||||||
model_input_cols = set(training_data.X.data.columns)
|
# Features should not contain target-related columns
|
||||||
|
feature_cols: pd.Index = training_set.features.columns
|
||||||
|
assert all(not col.startswith("darts_") for col in feature_cols)
|
||||||
|
|
||||||
# These columns should NOT be in model inputs
|
def test_inference_vs_training_feature_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
assert "geometry" not in model_input_cols
|
"""Test that inference batches have the same features as training data.
|
||||||
assert sample_ensemble_mllabels.covcol not in model_input_cols
|
|
||||||
assert sample_ensemble_mllabels.taskcol("binary") not in model_input_cols
|
|
||||||
|
|
||||||
# No dartsml_* columns should be in model inputs
|
|
||||||
for col in model_input_cols:
|
|
||||||
assert not col.startswith("dartsml_"), f"Column {col} should have been dropped from model inputs"
|
|
||||||
|
|
||||||
def test_inference_vs_training_feature_consistency(self, sample_ensemble):
|
|
||||||
"""Test that inference batches have the same features as training data after column dropping.
|
|
||||||
|
|
||||||
This test simulates the workflow in training.py and inference.py to ensure
|
This test simulates the workflow in training.py and inference.py to ensure
|
||||||
feature consistency between training and inference.
|
feature consistency between training and inference.
|
||||||
"""
|
"""
|
||||||
# Step 1: Create training dataset (as in training.py)
|
# Step 1: Create training set (as in training.py)
|
||||||
# Use only a small subset by creating just one batch
|
training_set: TrainingSet = sample_ensemble.create_training_set(
|
||||||
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
|
task="binary", target="darts_v1", device="cpu", cache_mode="none"
|
||||||
training_feature_cols = set(training_data.X.data.columns)
|
)
|
||||||
|
training_feature_cols: set[str] = set(training_set.features.columns)
|
||||||
|
|
||||||
# Step 2: Create inference batch (as in inference.py)
|
# Step 2: Create inference batch (as in inference.py)
|
||||||
# Get just the first batch to speed up test
|
batch_generator: Generator[pd.DataFrame] = sample_ensemble.create_inference_df(
|
||||||
batch_generator = sample_ensemble.create_batches(batch_size=100, cache_mode="n")
|
batch_size=100, cache_mode="none"
|
||||||
batch = next(batch_generator, None)
|
)
|
||||||
|
batch: pd.DataFrame | None = next(batch_generator, None)
|
||||||
|
|
||||||
if batch is None:
|
if batch is None:
|
||||||
pytest.skip("No batches created (dataset might be empty)")
|
pytest.skip("No batches created (dataset might be empty)")
|
||||||
|
|
||||||
# Simulate the column dropping in predict_proba
|
assert batch is not None # For type checker
|
||||||
cols_to_drop = ["geometry"]
|
inference_feature_cols: set[str] = set(batch.columns)
|
||||||
if sample_ensemble.target == "darts_mllabels":
|
|
||||||
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
|
|
||||||
else:
|
|
||||||
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
|
|
||||||
|
|
||||||
inference_batch = batch.drop(columns=cols_to_drop)
|
|
||||||
inference_feature_cols = set(inference_batch.columns)
|
|
||||||
|
|
||||||
# The features should match!
|
# The features should match!
|
||||||
assert training_feature_cols == inference_feature_cols, (
|
assert training_feature_cols == inference_feature_cols, (
|
||||||
|
|
@ -185,85 +207,64 @@ class TestDatasetEnsemble:
|
||||||
f"Inference features ({len(inference_feature_cols)}): {sorted(inference_feature_cols)}"
|
f"Inference features ({len(inference_feature_cols)}): {sorted(inference_feature_cols)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_inference_vs_training_feature_consistency_mllabels(self, sample_ensemble_mllabels):
|
def test_all_tasks_feature_consistency(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
"""Test feature consistency for mllabels target."""
|
|
||||||
training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")
|
|
||||||
training_feature_cols = set(training_data.X.data.columns)
|
|
||||||
|
|
||||||
# Get just the first batch to speed up test
|
|
||||||
batch_generator = sample_ensemble_mllabels.create_batches(batch_size=100, cache_mode="n")
|
|
||||||
batch = next(batch_generator, None)
|
|
||||||
|
|
||||||
if batch is None:
|
|
||||||
pytest.skip("No batches created (dataset might be empty)")
|
|
||||||
|
|
||||||
# Simulate the column dropping in predict_proba
|
|
||||||
cols_to_drop = ["geometry"]
|
|
||||||
if sample_ensemble_mllabels.target == "darts_mllabels":
|
|
||||||
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
|
|
||||||
else:
|
|
||||||
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
|
|
||||||
|
|
||||||
inference_batch = batch.drop(columns=cols_to_drop)
|
|
||||||
inference_feature_cols = set(inference_batch.columns)
|
|
||||||
|
|
||||||
assert training_feature_cols == inference_feature_cols, (
|
|
||||||
f"Feature mismatch between training and inference!\n"
|
|
||||||
f"Only in training: {training_feature_cols - inference_feature_cols}\n"
|
|
||||||
f"Only in inference: {inference_feature_cols - training_feature_cols}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_all_tasks_feature_consistency(self, sample_ensemble):
|
|
||||||
"""Test that all task types produce consistent features."""
|
"""Test that all task types produce consistent features."""
|
||||||
tasks = ["binary", "count", "density"]
|
feature_sets: dict[str, set[str]] = {}
|
||||||
feature_sets = {}
|
|
||||||
|
|
||||||
for task in tasks:
|
for task in all_tasks:
|
||||||
training_data = sample_ensemble.create_cat_training_dataset(task=task, device="cpu")
|
training_set: TrainingSet = sample_ensemble.create_training_set(
|
||||||
feature_sets[task] = set(training_data.X.data.columns)
|
task=task,
|
||||||
|
target="darts_v1",
|
||||||
|
device="cpu",
|
||||||
|
cache_mode="none",
|
||||||
|
)
|
||||||
|
feature_sets[task] = set(training_set.features.columns)
|
||||||
|
|
||||||
# All tasks should have the same features
|
# All tasks should have the same features
|
||||||
binary_features = feature_sets["binary"]
|
binary_features: set[str] = feature_sets["binary"]
|
||||||
for task, features in feature_sets.items():
|
for task, features in feature_sets.items():
|
||||||
assert features == binary_features, (
|
assert features == binary_features, (
|
||||||
f"Task '{task}' has different features than 'binary'.\n"
|
f"Task '{task}' has different features than 'binary'.\n"
|
||||||
f"Difference: {features.symmetric_difference(binary_features)}"
|
f"Difference: {features.symmetric_difference(binary_features)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_training_dataset_shapes(self, sample_ensemble):
|
def test_training_set_shapes(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
"""Test that training dataset has correct shapes."""
|
"""Test that training set has correct shapes."""
|
||||||
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
|
training_set: TrainingSet = sample_ensemble.create_training_set(
|
||||||
|
task="binary", target="darts_v1", device="cpu", cache_mode="none"
|
||||||
|
)
|
||||||
|
|
||||||
n_features = len(training_data.X.data.columns)
|
n_features: int = len(training_set.features.columns)
|
||||||
n_samples_train = training_data.X.train.shape[0]
|
n_samples_train: int = training_set.X.train.shape[0]
|
||||||
n_samples_test = training_data.X.test.shape[0]
|
n_samples_test: int = training_set.X.test.shape[0]
|
||||||
|
|
||||||
# Check X shapes
|
# Check X shapes
|
||||||
assert training_data.X.train.shape == (n_samples_train, n_features)
|
assert training_set.X.train.shape == (n_samples_train, n_features)
|
||||||
assert training_data.X.test.shape == (n_samples_test, n_features)
|
assert training_set.X.test.shape == (n_samples_test, n_features)
|
||||||
|
|
||||||
# Check y shapes
|
# Check y shapes
|
||||||
assert training_data.y.train.shape == (n_samples_train,)
|
assert training_set.y.train.shape == (n_samples_train,)
|
||||||
assert training_data.y.test.shape == (n_samples_test,)
|
assert training_set.y.test.shape == (n_samples_test,)
|
||||||
|
|
||||||
# Check that train + test = total samples
|
# Check that train + test = total samples
|
||||||
assert len(training_data.dataset) == n_samples_train + n_samples_test
|
assert len(training_set) == n_samples_train + n_samples_test
|
||||||
|
|
||||||
def test_no_nan_in_training_features(self, sample_ensemble):
|
def test_no_nan_in_training_features(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
"""Test that training features don't contain NaN values."""
|
"""Test that training features don't contain NaN values."""
|
||||||
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
|
training_set: TrainingSet = sample_ensemble.create_training_set(
|
||||||
|
task="binary", target="darts_v1", device="cpu", cache_mode="none"
|
||||||
|
)
|
||||||
|
|
||||||
# Convert to numpy for checking (handles both numpy and cupy arrays)
|
# Convert to numpy for checking (handles both numpy and cupy arrays)
|
||||||
X_train = np.asarray(training_data.X.train)
|
X_train: np.ndarray = np.asarray(training_set.X.train) # noqa: N806
|
||||||
X_test = np.asarray(training_data.X.test)
|
X_test: np.ndarray = np.asarray(training_set.X.test) # noqa: N806
|
||||||
|
|
||||||
assert not np.isnan(X_train).any(), "Training features contain NaN values"
|
assert not np.isnan(X_train).any(), "Training features contain NaN values"
|
||||||
assert not np.isnan(X_test).any(), "Test features contain NaN values"
|
assert not np.isnan(X_test).any(), "Test features contain NaN values"
|
||||||
|
|
||||||
def test_batch_coverage(self, sample_ensemble):
|
def test_batch_coverage(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
"""Test that batches cover all data without duplication."""
|
"""Test that batches cover all data without duplication."""
|
||||||
full_dataset = sample_ensemble.create(cache_mode="n")
|
batches: list[pd.DataFrame] = list(sample_ensemble.create_inference_df(batch_size=100, cache_mode="none"))
|
||||||
batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))
|
|
||||||
|
|
||||||
if len(batches) == 0:
|
if len(batches) == 0:
|
||||||
pytest.skip("No batches created (dataset might be empty)")
|
pytest.skip("No batches created (dataset might be empty)")
|
||||||
|
|
@ -277,34 +278,91 @@ class TestDatasetEnsemble:
|
||||||
assert len(overlap) == 0, f"Found {len(overlap)} duplicate cell_ids across batches"
|
assert len(overlap) == 0, f"Found {len(overlap)} duplicate cell_ids across batches"
|
||||||
batch_cell_ids.update(batch_ids)
|
batch_cell_ids.update(batch_ids)
|
||||||
|
|
||||||
# Check that all cell_ids from full dataset are in batches
|
# Check that all cell_ids from grid are in batches
|
||||||
full_cell_ids = set(full_dataset.index)
|
all_cell_ids = set(sample_ensemble.cell_ids.to_numpy())
|
||||||
assert batch_cell_ids == full_cell_ids, (
|
assert batch_cell_ids == all_cell_ids, (
|
||||||
f"Batch coverage mismatch.\n"
|
f"Batch coverage mismatch.\n"
|
||||||
f"Missing from batches: {full_cell_ids - batch_cell_ids}\n"
|
f"Missing from batches: {all_cell_ids - batch_cell_ids}\n"
|
||||||
f"Extra in batches: {batch_cell_ids - full_cell_ids}"
|
f"Extra in batches: {batch_cell_ids - all_cell_ids}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestDatasetEnsembleEdgeCases:
|
class TestDatasetEnsembleEdgeCases:
|
||||||
"""Test edge cases and error handling."""
|
"""Test edge cases and error handling."""
|
||||||
|
|
||||||
def test_invalid_task_raises_error(self, sample_ensemble):
|
def test_invalid_task_raises_error(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
"""Test that invalid task raises ValueError."""
|
"""Test that invalid task raises ValueError."""
|
||||||
with pytest.raises(ValueError, match="Invalid task"):
|
with pytest.raises(NotImplementedError, match=r"Task .* not supported"):
|
||||||
sample_ensemble.create_cat_training_dataset(task="invalid", device="cpu") # type: ignore
|
sample_ensemble.create_training_set(task="invalid", target="darts_v1", device="cpu") # type: ignore[arg-type]
|
||||||
|
|
||||||
def test_stats_method(self, sample_ensemble):
|
def test_target_labels_property(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
"""Test that get_stats returns expected structure."""
|
"""Test that TrainingSet provides target labels for categorical tasks."""
|
||||||
stats = sample_ensemble.get_stats()
|
training_set: TrainingSet = sample_ensemble.create_training_set(
|
||||||
|
task="binary", target="darts_v1", device="cpu", cache_mode="none"
|
||||||
|
)
|
||||||
|
|
||||||
assert "target" in stats
|
# Binary task should have labels
|
||||||
assert "num_target_samples" in stats
|
labels: list[str] | None = training_set.target_labels
|
||||||
assert "members" in stats
|
assert labels is not None
|
||||||
assert "total_features" in stats
|
assert isinstance(labels, list)
|
||||||
|
assert "No RTS" in labels
|
||||||
|
assert "RTS" in labels
|
||||||
|
|
||||||
# Check that members dict contains info for each member
|
def test_target_labels_none_for_regression(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
for member in sample_ensemble.members:
|
"""Test that regression tasks return None for target_labels."""
|
||||||
assert member in stats["members"]
|
training_set: TrainingSet = sample_ensemble.create_training_set(
|
||||||
assert "variables" in stats["members"][member]
|
task="count", target="darts_v1", device="cpu", cache_mode="none"
|
||||||
assert "num_features" in stats["members"][member]
|
)
|
||||||
|
|
||||||
|
# Regression task should return None
|
||||||
|
assert training_set.target_labels is None
|
||||||
|
|
||||||
|
def test_to_dataframe_method(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
|
"""Test that to_dataframe returns expected structure."""
|
||||||
|
training_set: TrainingSet = sample_ensemble.create_training_set(
|
||||||
|
task="binary", target="darts_v1", device="cpu", cache_mode="none"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get full dataset
|
||||||
|
df_full: pd.DataFrame = training_set.to_dataframe(split=None)
|
||||||
|
assert isinstance(df_full, pd.DataFrame)
|
||||||
|
assert "label" in df_full.columns
|
||||||
|
assert len(df_full) == len(training_set)
|
||||||
|
|
||||||
|
# Get train split
|
||||||
|
df_train = training_set.to_dataframe(split="train")
|
||||||
|
assert len(df_train) < len(df_full)
|
||||||
|
assert all(training_set.split[df_train.index] == "train")
|
||||||
|
|
||||||
|
# Get test split
|
||||||
|
df_test = training_set.to_dataframe(split="test")
|
||||||
|
assert len(df_test) < len(df_full)
|
||||||
|
assert all(training_set.split[df_test.index] == "test")
|
||||||
|
|
||||||
|
# Train + test should equal full
|
||||||
|
assert len(df_train) + len(df_test) == len(df_full)
|
||||||
|
|
||||||
|
def test_different_targets(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
|
"""Test that different target datasets can be used."""
|
||||||
|
for target in all_target_datasets:
|
||||||
|
try:
|
||||||
|
training_set: TrainingSet = sample_ensemble.create_training_set(
|
||||||
|
task="binary",
|
||||||
|
target=target,
|
||||||
|
device="cpu",
|
||||||
|
cache_mode="none",
|
||||||
|
)
|
||||||
|
assert len(training_set) > 0
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"Target {target} not available: {e}")
|
||||||
|
|
||||||
|
def test_feature_names_property(self, sample_ensemble: DatasetEnsemble) -> None:
|
||||||
|
"""Test that feature_names property returns list of feature names."""
|
||||||
|
training_set: TrainingSet = sample_ensemble.create_training_set(
|
||||||
|
task="binary", target="darts_v1", device="cpu", cache_mode="none"
|
||||||
|
)
|
||||||
|
|
||||||
|
feature_names: list[str] = training_set.feature_names
|
||||||
|
assert isinstance(feature_names, list)
|
||||||
|
assert len(feature_names) == training_set.features.shape[1]
|
||||||
|
assert all(isinstance(name, str) for name in feature_names)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue