#!/usr/bin/env python
"""Rerun inference for training results that are missing predicted probabilities.

This script searches through training result directories and identifies those
that have a trained model but are missing inference results. It then loads the
model and dataset configuration, reruns inference, and saves the results.
"""

import pickle
from pathlib import Path

import toml
from rich.console import Console
from rich.markup import escape
from rich.progress import track

from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.inference import predict_proba
from entropice.utils.paths import RESULTS_DIR

console = Console()

# File names inside a training result directory. Shared by discovery and by
# the rerun step so the two cannot drift apart.
MODEL_FILENAME = "best_estimator_model.pkl"
SETTINGS_FILENAME = "search_settings.toml"
PREDICTIONS_FILENAME = "predicted_probabilities.parquet"


def _is_incomplete(result_dir: Path) -> bool:
    """Return True if *result_dir* has a model and settings but no predictions."""
    return (
        result_dir.is_dir()
        and (result_dir / MODEL_FILENAME).exists()
        and (result_dir / SETTINGS_FILENAME).exists()
        and not (result_dir / PREDICTIONS_FILENAME).exists()
    )


def find_incomplete_trainings() -> list[Path]:
    """Find training result directories missing inference results.

    A directory under ``RESULTS_DIR`` whose name matches ``*_cv*`` is
    incomplete when it contains both the pickled model and the search
    settings but no predictions file.

    Returns:
        list[Path]: Directories with trained models but missing predictions,
        sorted by path. Empty if ``RESULTS_DIR`` does not exist.
    """
    if not RESULTS_DIR.exists():
        console.print(f"[yellow]Results directory not found: {escape(str(RESULTS_DIR))}[/yellow]")
        return []

    # Sort so the processing order is deterministic; glob order is not guaranteed.
    return [result_dir for result_dir in sorted(RESULTS_DIR.glob("*_cv*")) if _is_incomplete(result_dir)]


def _build_ensemble(settings: dict) -> DatasetEnsemble:
    """Reconstruct the DatasetEnsemble described by a saved ``[settings]`` table.

    ``grid``, ``level``, ``target`` and ``members`` are required keys (a
    missing one raises KeyError). The remaining options fall back to the
    defaults shown below when absent.
    """
    return DatasetEnsemble(
        grid=settings["grid"],
        level=settings["level"],
        target=settings["target"],
        members=settings["members"],
        dimension_filters=settings.get("dimension_filters", {}),
        variable_filters=settings.get("variable_filters", {}),
        filter_target=settings.get("filter_target", False),
        add_lonlat=settings.get("add_lonlat", True),
    )


def _save_predictions(preds, preds_file: Path) -> None:
    """Write *preds* to *preds_file* via a temporary file and an atomic rename.

    A failed in-place write could leave a truncated parquet file behind.
    Because find_incomplete_trainings() treats any existing predictions file
    as done, such a directory would never be retried.

    Args:
        preds: The object returned by ``predict_proba``. Presumably a
            (Geo)DataFrame; only ``to_parquet`` is used here.
        preds_file (Path): Final destination of the predictions.
    """
    tmp_file = preds_file.with_name(preds_file.name + ".tmp")
    try:
        preds.to_parquet(tmp_file)
        # Same directory, hence same filesystem, so the rename is atomic on POSIX.
        tmp_file.replace(preds_file)
    finally:
        # No-op after a successful replace; removes a partial file otherwise.
        tmp_file.unlink(missing_ok=True)


def rerun_inference(result_dir: Path) -> bool:
    """Rerun inference for a training result directory.

    Loads the ``[settings]`` table from the search settings TOML, rebuilds
    the dataset ensemble, loads the pickled estimator, predicts class
    probabilities, and writes them next to the model.

    Args:
        result_dir (Path): Path to the training result directory.

    Returns:
        bool: True if successful, False otherwise. Any exception raised while
        processing is reported to the console rather than propagated.
    """
    try:
        console.print(f"\n[cyan]Processing: {escape(result_dir.name)}[/cyan]")

        with (result_dir / SETTINGS_FILENAME).open(encoding="utf-8") as f:
            settings = toml.load(f)["settings"]

        ensemble = _build_ensemble(settings)

        # NOTE(review): pickle.load can execute arbitrary code. Only run this on
        # result directories you trust (presumably artifacts written by this
        # project's own training runs).
        with (result_dir / MODEL_FILENAME).open("rb") as f:
            clf = pickle.load(f)

        console.print("[green]✓[/green] Loaded model and settings")

        console.print("[yellow]Running inference...[/yellow]")
        preds = predict_proba(ensemble, clf=clf, classes=settings["classes"])

        preds_file = result_dir / PREDICTIONS_FILENAME
        _save_predictions(preds, preds_file)
        console.print(f"[green]✓[/green] Saved {len(preds)} predictions to {escape(preds_file.name)}")
        return True

    except Exception as e:
        # Per-directory boundary: report the failure and let the caller move on.
        # escape() keeps bracketed text in the message (e.g. "[/some/path]") from
        # being parsed as Rich markup, which would raise MarkupError right here.
        console.print(f"[red]✗ Error processing {escape(result_dir.name)}: {escape(str(e))}[/red]")
        console.print_exception()
        return False


def main() -> int:
    """Rerun missing inferences for incomplete training results.

    Returns:
        int: Process exit code: 0 if nothing needed rerunning or every rerun
        succeeded, 1 if at least one rerun failed.
    """
    console.print("[bold blue]Searching for incomplete training results...[/bold blue]")
    incomplete_dirs = find_incomplete_trainings()

    if not incomplete_dirs:
        console.print("[green]No incomplete trainings found. All trainings have predictions![/green]")
        return 0

    console.print(f"[yellow]Found {len(incomplete_dirs)} training(s) missing predictions:[/yellow]")
    for d in incomplete_dirs:
        console.print(f"  • {escape(d.name)}")

    console.print(f"\n[bold]Processing {len(incomplete_dirs)} training result(s)...[/bold]\n")

    successful = 0
    failed = 0
    # Pass our console so the progress bar and the per-directory messages share
    # one live display instead of two consoles writing over each other.
    for result_dir in track(incomplete_dirs, description="Rerunning inference", console=console):
        if rerun_inference(result_dir):
            successful += 1
        else:
            failed += 1

    console.print("\n[bold]Summary:[/bold]")
    console.print(f"  [green]Successful: {successful}[/green]")
    console.print(f"  [red]Failed: {failed}[/red]")
    return 1 if failed else 0


if __name__ == "__main__":
    raise SystemExit(main())