- Added a section to display test metrics for model performance on the held-out test set. - Implemented confusion matrix visualization to analyze prediction breakdown. - Refactored sidebar settings to streamline metric selection and improve user experience. - Updated cross-validation statistics to compare CV performance with test metrics. - Enhanced DatasetEnsemble methods to handle empty data scenarios gracefully. - Introduced debug scripts to assist in identifying feature mismatches and validating dataset preparation. - Added comprehensive tests for DatasetEnsemble to ensure feature consistency and correct behavior across various scenarios.
144 lines
4.5 KiB
Python
144 lines
4.5 KiB
Python
#!/usr/bin/env python
"""Rerun inference for training results that are missing predicted probabilities.

This script searches through training result directories and identifies those that have
a trained model but are missing inference results. It then loads the model and dataset
configuration, reruns inference, and saves the results.
"""
|
|
|
|
import pickle
|
|
from pathlib import Path
|
|
|
|
import toml
|
|
from rich.console import Console
|
|
from rich.progress import track
|
|
|
|
from entropice.ml.dataset import DatasetEnsemble
|
|
from entropice.ml.inference import predict_proba
|
|
from entropice.utils.paths import RESULTS_DIR
|
|
|
|
# Shared Rich console; every function in this script prints its status output through it.
console = Console()
|
|
|
|
|
|
def find_incomplete_trainings() -> list[Path]:
    """Collect result directories that hold a trained model but no saved predictions.

    Returns:
        list[Path]: Directories under ``RESULTS_DIR`` matching ``*_cv*`` that contain
            both ``best_estimator_model.pkl`` and ``search_settings.toml`` but no
            ``predicted_probabilities.parquet``. Empty when ``RESULTS_DIR`` is absent.

    """
    if not RESULTS_DIR.exists():
        console.print(f"[yellow]Results directory not found: {RESULTS_DIR}[/yellow]")
        return []

    def _needs_inference(candidate: Path) -> bool:
        # Both inference inputs must be present while the output is still missing.
        has_inputs = (candidate / "best_estimator_model.pkl").exists() and (
            candidate / "search_settings.toml"
        ).exists()
        return has_inputs and not (candidate / "predicted_probabilities.parquet").exists()

    # Keep the order in which glob yields the entries; plain files are skipped.
    return [
        candidate
        for candidate in RESULTS_DIR.glob("*_cv*")
        if candidate.is_dir() and _needs_inference(candidate)
    ]
|
|
|
|
|
|
def rerun_inference(result_dir: Path) -> bool:
    """Rerun inference for a training result directory.

    Rebuilds the ``DatasetEnsemble`` from ``search_settings.toml``, unpickles
    ``best_estimator_model.pkl``, calls ``predict_proba`` and writes the result to
    ``predicted_probabilities.parquet`` inside ``result_dir``.

    Args:
        result_dir (Path): Path to the training result directory.

    Returns:
        bool: True if successful, False otherwise. Any exception is reported on the
            console (message plus traceback) rather than propagated, so one broken
            directory does not stop a batch run.

    """
    try:
        console.print(f"\n[cyan]Processing: {result_dir.name}[/cyan]")

        # Load settings. TOML documents are UTF-8 by specification, so do not rely
        # on the platform's default text encoding.
        settings_file = result_dir / "search_settings.toml"
        with open(settings_file, encoding="utf-8") as f:
            settings_data = toml.load(f)

        settings = settings_data["settings"]

        # Reconstruct DatasetEnsemble from settings; optional keys fall back to defaults.
        ensemble = DatasetEnsemble(
            grid=settings["grid"],
            level=settings["level"],
            target=settings["target"],
            members=settings["members"],
            dimension_filters=settings.get("dimension_filters", {}),
            variable_filters=settings.get("variable_filters", {}),
            filter_target=settings.get("filter_target", False),
            add_lonlat=settings.get("add_lonlat", True),
        )

        # Load trained model.
        # NOTE(security): unpickling executes arbitrary code. Only run this script on
        # result directories produced by trusted training runs.
        model_file = result_dir / "best_estimator_model.pkl"
        with open(model_file, "rb") as f:
            clf = pickle.load(f)

        console.print("[green]✓[/green] Loaded model and settings")

        # Get class labels
        classes = settings["classes"]

        # Run inference
        console.print("[yellow]Running inference...[/yellow]")
        preds = predict_proba(ensemble, clf=clf, classes=classes)

        # Save predictions
        preds_file = result_dir / "predicted_probabilities.parquet"
        preds.to_parquet(preds_file)

        console.print(f"[green]✓[/green] Saved {len(preds)} predictions to {preds_file.name}")
        return True

    except Exception as e:
        import traceback

        # The exception text and traceback are arbitrary strings. Print them with
        # markup disabled so bracketed fragments (e.g. "list[str]" or "[/x]") are
        # shown verbatim instead of being parsed as Rich tags, which could garble
        # the output or raise a MarkupError inside this handler.
        console.print(f"✗ Error processing {result_dir.name}: {e}", style="red", markup=False)
        console.print(traceback.format_exc(), style="red", markup=False)
        return False
|
|
|
|
|
|
def main():
    """Locate training results without predictions and rerun inference for each.

    Prints the list of affected directories, processes them one by one with a
    progress bar, and finishes with a success/failure summary.
    """
    console.print("[bold blue]Searching for incomplete training results...[/bold blue]")

    pending = find_incomplete_trainings()
    if not pending:
        console.print("[green]No incomplete trainings found. All trainings have predictions![/green]")
        return

    total = len(pending)
    console.print(f"[yellow]Found {total} training(s) missing predictions:[/yellow]")
    for directory in pending:
        console.print(f" • {directory.name}")

    console.print(f"\n[bold]Processing {total} training result(s)...[/bold]\n")

    # One truthiness flag per directory, in processing order.
    outcomes = [
        bool(rerun_inference(directory))
        for directory in track(pending, description="Rerunning inference")
    ]
    successful = sum(outcomes)
    failed = len(outcomes) - successful

    console.print("\n[bold]Summary:[/bold]")
    console.print(f" [green]Successful: {successful}[/green]")
    console.print(f" [red]Failed: {failed}[/red]")
|
|
|
|
|
|
# Script entry point: run the backfill only when executed directly, not on import.
if __name__ == "__main__":
    main()
|