From c92e856c557952d5059e98e4a847dce088378525 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Wed, 7 Jan 2026 15:56:02 +0100 Subject: [PATCH] Enhance training analysis page with test metrics and confusion matrix - Added a section to display test metrics for model performance on the held-out test set. - Implemented confusion matrix visualization to analyze prediction breakdown. - Refactored sidebar settings to streamline metric selection and improve user experience. - Updated cross-validation statistics to compare CV performance with test metrics. - Enhanced DatasetEnsemble methods to handle empty data scenarios gracefully. - Introduced debug scripts to assist in identifying feature mismatches and validating dataset preparation. - Added comprehensive tests for DatasetEnsemble to ensure feature consistency and correct behavior across various scenarios. --- pixi.lock | 223 +++++---- pyproject.toml | 2 +- scripts/recalculate_test_metrics.py | 195 ++++++++ scripts/rechunk_zarr.py | 58 +++ scripts/rerun_missing_inference.py | 144 ++++++ .../plots/hyperparameter_analysis.py | 448 +++++++++++++----- src/entropice/dashboard/plots/inference.py | 33 +- src/entropice/dashboard/plots/model_state.py | 2 +- src/entropice/dashboard/plots/overview.py | 92 ++-- .../dashboard/utils/class_ordering.py | 70 +++ src/entropice/dashboard/utils/geometry.py | 12 +- src/entropice/dashboard/utils/loaders.py | 15 +- .../dashboard/views/inference_page.py | 1 - .../dashboard/views/overview_page.py | 282 +++++------ .../dashboard/views/training_analysis_page.py | 156 ++++-- .../dashboard/views/training_data_page.py | 1 + src/entropice/ml/dataset.py | 100 +++- src/entropice/ml/inference.py | 5 + src/entropice/ml/training.py | 61 +++ src/entropice/utils/types.py | 14 + tests/debug_arcticdem_batch.py | 33 ++ tests/debug_feature_mismatch.py | 72 +++ tests/test_dataset.py | 310 ++++++++++++ 23 files changed, 1845 insertions(+), 484 deletions(-) create mode 100644 scripts/recalculate_test_metrics.py create mode 100644 scripts/rechunk_zarr.py create mode 100644 scripts/rerun_missing_inference.py create mode 100644 src/entropice/dashboard/utils/class_ordering.py create mode 100644 tests/debug_arcticdem_batch.py create mode 100644 tests/debug_feature_mismatch.py create mode 100644 tests/test_dataset.py diff --git a/pixi.lock b/pixi.lock index 8ccac5f..ff5881d 100644 --- a/pixi.lock +++ b/pixi.lock @@ -473,16 +473,16 @@ environments: - pypi: https://files.pythonhosted.org/packages/69/ce/68d6e31f0a75a5cccc03535e47434c0ca4be37fe950e93117e455cbc362c/antimeridian-0.4.5-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5b/03/c17464bbf682ea87e7e3de2ddc63395e359a78ae9c01f55fc78759ecbd79/anywidget-0.9.21-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/e0/b1/0542e0cab6f49f151a2d7a42400f84f706fc0b64e85dc1f56708b2e9fd37/array_api_compat-1.12.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/1d/05/2709750ddb088eb2fc5053ba214b4f54334d15d4cb28217e2956b5507bac/array_api_extra-0.9.1-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/df/5d/493b1b5528ab5072feae30821ff3a07b7a0474213d548efb1fdf135f85c1/array_api_compat-1.13.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/21/2b/bfa1cfe370dd4ed51f834f2c6ad93b7f6263b83615ab96ad91094cc98ec6/array_api_extra-0.9.2-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/fb/1f/2903ef412cb82ba1f2211692f7339fd7c5aeb2764f2a97f0b6a9a18bbf52/arro3_compute-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/31/4a/72dc383d1a0d14f1d453e334e3461e229762edb1bf3f75b3ab977e9386ed/arro3_core-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/1b/df/2a5a1306dc1699b51b02c1c38c55f3564a8c4f84087c23c61e7e7ae37dfa/arro3_io-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/c3/1c/f06ad85180e7dd9855aa5ede901bfc2be858d7bee17d4e978a14c0ecec14/astropy-7.2.0-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/1f/07/50501947849e780cb5580ebcd7af08c14d431640562e18a8ac2b055c90ec/astropy_iers_data-0.2025.12.22.0.40.30-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/57/61/2d06c08f022c9b617b79f6c55d88e596c1795a1d211e6bf584ac4b9e9506/astropy_iers_data-0.2026.1.5.0.43.43-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ee/34/a9914e676971a13d6cc671b1ed172f9804b50a3a80a143ff196e52f4c7ee/azure_core-1.37.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/3d/9e/1c90a122ea6180e8c72eb7294adc92531b0e08eb3d2324c2ba70d37f4802/azure_storage_blob-12.27.1-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d8/3a/6ef2047a072e54e1142718d433d50e9514c999a58f51abfff7902f3a72f8/azure_storage_blob-12.28.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/96/9a/663251dfb35aaddcbdbef78802ea5a9d3fad9d5fadde8774eacd9e1bfbb7/boost_histogram-1.6.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/3c/56/f47a80254ed4991cce9a2f6d8ae8aafbc8df1c3270e966b2927289e5a12f/boto3-1.41.5-py3-none-any.whl @@ -494,12 +494,11 @@ environments: - pypi: https://files.pythonhosted.org/packages/27/27/6414b1b7e5e151300c54e28ad1cf3e3b34fe66dc3256a989b031166b1ba3/cdshealpix-0.7.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/a3/8f/c42a98f933022c7de00142526c9b6b7429fdcd0fc66c952b4ebbf0ff3b7f/cf_xarray-0.10.10-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ba/08/52f06ff2f04d376f9cd2c211aefcf2b37f1978e43289341f362fc99f6a0e/cftime-1.6.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/3d/9a/2abecb28ae875e39c8cad711eb1186d8d14eab564705325e77e4e6ab9ae5/click_plugins-1.1.1.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/73/86/43fa9f15c5b9fb6e82620428827cd3c284aa933431405d1bcf5231ae3d3e/cligj-0.7.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c9/56/e7e69b427c3878352c2fb9b450bd0e19ed552753491d39d7d0a2f5226d41/cryptography-46.0.3-cp311-abi3-manylinux_2_28_x86_64.whl - pypi: 
https://files.pythonhosted.org/packages/fa/25/0be9314cd72fe2ee2ef89ceb1f438bc156428a12177d684040456eee4a56/cupy_xarray-0.1.4-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/8d/05/8efadba80e1296526e69c1dceba8b0f0bc3756e8d69f6ed9b0e647cf3169/cyclopts-4.4.1-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/20/5b/0eceb9a5990de9025733a0d212ca43649ba9facd58b8552b6bf93c11439d/cyclopts-4.4.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/25/3e/e27078370414ef35fafad2c06d182110073daaeb5d3bf734b0b1eeefe452/debugpy-1.8.19-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl @@ -522,10 +521,10 @@ environments: - pypi: https://files.pythonhosted.org/packages/31/b3/802576f2ea5dcb48501bb162e4c7b7b3ca5654a42b2c968ef98a797a4c79/geographiclib-2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e5/15/cf2a69ade4b194aa524ac75112d5caac37414b20a3a03e6865dfe0bd1539/geopy-2.4.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/01/61/d4b89fec821f72385526e1b9d9a3a0385dda4a72b206d28049e2c7cd39b8/gitpython-3.1.45-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ed/d4/90197b416cb61cefd316964fd9e7bd8324bcbafabf40eef14a9f20b81974/google_api_core-2.28.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/96/58/c1e716be1b055b504d80db2c8413f6c6a890a6ae218a65f178b63bc30356/google_api_python_client-2.187.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/c6/97/451d55e05487a5cd6279a01a7e34921858b16f7dc8aa38a2c684743cd2b3/google_auth-2.45.0-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/db/18/79e9008530b79527e0d5f79e7eef08d3b179b7f851cfd3a2f27822fbdfa9/google_auth-2.47.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/d5/3c97526c8796d3caf5f4b3bed2b05e8a7102326f00a334e7a438237f3b22/google_auth_httplib2-0.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/89/20/bfa472e327c8edee00f04beecc80baeddd2ab33ee0e86fd7654da49d45e9/google_cloud_core-2.5.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2d/80/6e5c7c83cea15ed4dfc4843b9df9db0716bc551ac938f7b5dd18a72bd5e4/google_cloud_storage-3.7.0-py3-none-any.whl @@ -537,13 +536,14 @@ environments: - pypi: https://files.pythonhosted.org/packages/d6/49/1f35189c1ca136b2f041b72402f2eb718bdcb435d9e88729fe6f6909c45d/h5netcdf-1.7.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d9/69/4402ea66272dacc10b298cca18ed73e1c0791ff2ae9ed218d3859f9698ac/h5py-3.15.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/8c/a2/0d269db0f6163be503775dc8b6a6fa15820cc9fdc866f6ba608d86b721f2/httplib2-0.31.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/94/56/c5e8db63ba0e27b310a0b4c384da555b361741e7d186044d31f400c0419e/icechunk-1.1.14-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + - pypi: 
https://files.pythonhosted.org/packages/8c/d7/db466e07a21553441adbf915f0913a3f8fecece364cacb2392f11be267be/icechunk-1.1.15-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/0f/b66d63d4a5426c09005d3713b056e634e00e69788fdc88d1ffe40e5b7654/ipycytoscape-1.3.3-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ca/d3/642a6dc3db8ea558a9b5fbc83815b197861868dc98f98a789b85c7660670/ipyevents-2.0.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/00/60/249e3444fcd9c833704741769981cd02fe2c7ce94126b1394e7a3b26e543/ipyfilechooser-0.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a3/17/20c2552266728ceba271967b87919664ecc0e33efca29c3efc6baf88c5f9/ipykernel-7.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/69/e9858f2c0b99bf9f036348d1c84b8026f438bb6875effe6a9bcd9883dada/ipyleaflet-0.20.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/f1/df/8ee1c5dd1e3308b5d5b2f2dfea323bb2f3827da8d654abb6642051199049/ipython-9.8.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/86/92/162cfaee4ccf370465c5af1ce36a9eacec1becb552f2033bb3584e6f640a/ipython-9.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/56/6d/0d9848617b9f753b87f214f1c682592f7ca42de085f564352f10f0843026/ipywidgets-8.1.8-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/15/aa/0aca39a37d3c7eb941ba736ede56d689e7be91cab5d9ca846bde3999eba6/isodate-0.7.2-py3-none-any.whl @@ -558,11 +558,11 @@ environments: - pypi: https://files.pythonhosted.org/packages/93/cf/be4e93afbfa0def2cd6fac9302071db0bd6d0617999ecbf53f92b9398de3/multiurl-0.3.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/97/1a/78b19893197ed7525edfa7f124a461626541e82aec694a468ba97755c24e/netcdf4-1.7.3-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/7b/7a/a8d32501bb95ecff342004a674720164f95ad616f269450b3bc13dc88ae3/netcdf4-1.7.4-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/ae/d3/ff8f1b9968aa4dcd1da1880322ed492314cc920998182e549b586c895a17/numbagg-0.9.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c4/e6/d359fdd37498e74d26a167f7a51e54542e642ea47181eb4e643a69a066c3/numcodecs-0.16.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b0/e0/760e73c111193db5ca37712a148e4807d1b0c60302ab31e4ada6528ca34d/numpy_groupies-0.11.3-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/4a/4e/44dbb46b3d1b0ec61afda8e84837870f2f9ace33c564317d59b70bc19d3e/nvidia_nccl_cu12-2.28.9-py3-none-manylinux_2_18_x86_64.whl + - pypi: 
https://files.pythonhosted.org/packages/23/2d/609d0392d992259c6dc39881688a7fc13b1397a668bc360fbd68d1396f85/nvidia_nccl_cu12-2.29.2-py3-none-manylinux_2_18_x86_64.whl - pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/99/e2/311fb383d9534eef7bfbe858fad931b6e3dbe85843c50592f50063c3bc83/odc_geo-0.4.10-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/84/99/6636f7097a5e461d560317024522279f52931b5a52c8caa0755a14d5f1fd/odc_loader-0.6.0-py3-none-any.whl @@ -572,11 +572,12 @@ environments: - pypi: https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e7/c3/3031c931098de393393e1f93a38dc9ed6805d86bb801acc3cf2d5bd1e6b7/plotly-6.5.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a8/87/77cc11c7a9ea9fd05503def69e3d18605852cd0d4b0d3b8f15bbeb3ef1d1/pooch-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/cd/24/3b7a0818484df9c28172857af32c2397b6d8fcd99d9468bd4684f98ebf0a/proto_plus-1.27.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/56/13/333b8f421738f149d4fe5e49553bc2a2ab75235486259f689b4b91f96cec/protobuf-6.33.2-cp39-abi3-manylinux2014_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/ff/7b/e9a6fa461ef266c5a23485004934b8f08a2a8ddc447802161ea56d9837dd/psygnal-0.15.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/2d/4f/3593e5adb88a188c798604aed95fbc1479f30230e7f51e8f2c770e6a3832/psygnal-0.15.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl @@ -587,8 +588,9 @@ environments: - pypi: https://files.pythonhosted.org/packages/82/06/cad54e8ce758bd836ee5411691cbd49efeb9cc611b374670fce299519334/pyshp-3.0.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/9f/86/3ec01436c6235a23a80e978b261a87481c1acaf626a5c618e9edac30e5e1/pystac-1.14.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5d/d2/5f6367b14c9f250d1a6725d18bd1e9584f5ab1587e292f3a847e59189598/pystac_client-0.9.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/88/ae/baf3a8057d8129896a7e02619df43ea0d918fc5b2bb66eb6e2470595fbac/python_box-7.3.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/7b/84/66c0d9cca2a09074ec2ce6fffa87709ca51b0d197ae742d835e841bac660/rasterio-1.4.4-cp313-cp313-manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/48/4a/1af9aa9810fb30668568f2c4dd3eec2412c8e9762b69201d971c509b295e/rasterio-1.5.0-cp313-cp313-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/f2/98/7e6d147fd16a10a5f821db6e25f192265d6ecca3d82957a4fdd592cad49c/ratelim-0.1.6-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/34/83/a485250bc09db55e4b4389d99e583fac871ceeaaa4620b67a31d8db95ef5/rechunker-0.5.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/13/2f/b4530fbf948867702d0a3f27de4a6aab1d156f406d72852ab902c4d04de9/rich_rst-1.3.2-py3-none-any.whl @@ -608,16 +610,16 @@ environments: - pypi: https://files.pythonhosted.org/packages/c0/95/6b7873f0267973ebd55ba9cd33a690b35a116f2779901ef6185a0e21864d/streamlit-1.52.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/b5/fc/5e2988590ff2e0128eea6446806c904445a44e17256c67141573ea16b5a5/textual-6.11.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/84/38/47fab2a5fad163ca4851f7a20eb2442491cc63bf2756ec4ef161bc1461dd/textual-7.0.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8d/c0/fdf9d3ee103ce66a55f0532835ad5e154226c5222423c6636ba049dc42fc/traittypes-0.2.3-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/94/fc/1d34ec891900d9337169ff9f8252fcaa633ae5c4d36b67effd849ed4f9ac/ty-0.0.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e7/c1/56ef16bf5dcd255155cc736d276efa6ae0a5c26fd685e28f0412a4013c01/types_pytz-2025.2.0.20251108-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/95/20/92e3083b0e854943015bc8a7866e284ead9efadf9bf6809e6fce3b7ded61/ultraplot-1.66.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/43/6c/b26831b890b37c09882f6406efd31441c8e512bf1efbc967b9d867c5e02b/ultraplot-1.70.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a9/99/3ae339466c9183ea5b8ae87b34c0b897eda475d2aec2307cae60e5cd4f29/uritemplate-4.2.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e6/9f/ca52771fe972e0dcc5167fedb609940e01516066938ff2ee28b273ae4f29/vega_datasets-0.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/04/d5/81d1403788f072e7d0e2b2fe539a0ae4410f27886ff52df094e5348c99ea/vegafusion-2.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/a7/6b/48f6d47a92eaf6f0dd235146307a7eb0d179b78d2faebc53aca3f1e49177/vl_convert_python-1.8.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + - pypi: 
https://files.pythonhosted.org/packages/6f/61/dc6f4a38cf1b8699f64c57d7f021ca42c39bfe782d8a6eaefb7e8418e925/vl_convert_python-1.9.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/3f/0e/fa3b193432cfc60c93b42f3be03365f5f909d2b3ea410295cf36df739e31/widgetsnbextension-4.0.15-py3-none-any.whl @@ -884,10 +886,10 @@ packages: - pkg:pypi/argon2-cffi-bindings?source=hash-mapping size: 35943 timestamp: 1762509452935 -- pypi: https://files.pythonhosted.org/packages/e0/b1/0542e0cab6f49f151a2d7a42400f84f706fc0b64e85dc1f56708b2e9fd37/array_api_compat-1.12.0-py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/df/5d/493b1b5528ab5072feae30821ff3a07b7a0474213d548efb1fdf135f85c1/array_api_compat-1.13.0-py3-none-any.whl name: array-api-compat - version: 1.12.0 - sha256: a0b4795b6944a9507fde54679f9350e2ad2b1e2acf4a2408a098cdc27f890a8b + version: 1.13.0 + sha256: c15026a0ddec42815383f07da285472e1b1ff2e632eb7afbcfe9b08fcbad9bf1 requires_dist: - cupy ; extra == 'cupy' - dask>=2024.9.0 ; extra == 'dask' @@ -905,16 +907,16 @@ packages: - array-api-strict ; extra == 'dev' - dask[array]>=2024.9.0 ; extra == 'dev' - jax[cpu] ; extra == 'dev' + - ndonnx ; extra == 'dev' - numpy>=1.22 ; extra == 'dev' - pytest ; extra == 'dev' - torch ; extra == 'dev' - sparse>=0.15.1 ; extra == 'dev' - - ndonnx ; extra == 'dev' requires_python: '>=3.10' -- pypi: https://files.pythonhosted.org/packages/1d/05/2709750ddb088eb2fc5053ba214b4f54334d15d4cb28217e2956b5507bac/array_api_extra-0.9.1-py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/21/2b/bfa1cfe370dd4ed51f834f2c6ad93b7f6263b83615ab96ad91094cc98ec6/array_api_extra-0.9.2-py3-none-any.whl name: array-api-extra - version: 0.9.1 - sha256: 78b3e6605d1cdc9a66bb49e340e1bb620f045f1809a4e146d74500c3cb813b74 + version: 0.9.2 + sha256: d0643a9a4e981746057649accad068ca0fe4066d890f6a95d8b4cd5131b3b661 requires_dist: - array-api-compat>=1.12.0,<2 requires_python: '>=3.10' @@ -1026,10 +1028,10 @@ packages: - astropy[dev] ; extra == 'dev-all' - astropy[test-all] ; extra == 'dev-all' requires_python: '>=3.11' -- pypi: https://files.pythonhosted.org/packages/1f/07/50501947849e780cb5580ebcd7af08c14d431640562e18a8ac2b055c90ec/astropy_iers_data-0.2025.12.22.0.40.30-py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/57/61/2d06c08f022c9b617b79f6c55d88e596c1795a1d211e6bf584ac4b9e9506/astropy_iers_data-0.2026.1.5.0.43.43-py3-none-any.whl name: astropy-iers-data - version: 0.2025.12.22.0.40.30 - sha256: 2fbc71988d96aa29566667c6568a2bc5ca00748174b1f8ac3e9f7b09d4c27cac + version: 0.2026.1.5.0.43.43 + sha256: fe2c35e9abc99142083d717ea76bf7bde373dc12e502aaeced28ae4ff9bfc345 requires_dist: - pytest ; extra == 'docs' - hypothesis ; extra == 'test' @@ -1298,10 +1300,10 @@ packages: purls: [] size: 249684 timestamp: 1761066654684 -- pypi: https://files.pythonhosted.org/packages/3d/9e/1c90a122ea6180e8c72eb7294adc92531b0e08eb3d2324c2ba70d37f4802/azure_storage_blob-12.27.1-py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/d8/3a/6ef2047a072e54e1142718d433d50e9514c999a58f51abfff7902f3a72f8/azure_storage_blob-12.28.0-py3-none-any.whl name: azure-storage-blob - version: 12.27.1 - sha256: 
65d1e25a4628b7b6acd20ff7902d8da5b4fde8e46e19c8f6d213a3abc3ece272 + version: 12.28.0 + sha256: 00fb1db28bf6a7b7ecaa48e3b1d5c83bfadacc5a678b77826081304bd87d6461 requires_dist: - azure-core>=1.30.0 - cryptography>=2.1.4 @@ -1772,16 +1774,6 @@ packages: - pkg:pypi/click?source=hash-mapping size: 97676 timestamp: 1764518652276 -- pypi: https://files.pythonhosted.org/packages/3d/9a/2abecb28ae875e39c8cad711eb1186d8d14eab564705325e77e4e6ab9ae5/click_plugins-1.1.1.2-py2.py3-none-any.whl - name: click-plugins - version: 1.1.1.2 - sha256: 008d65743833ffc1f5417bf0e78e8d2c23aab04d9745ba817bd3e71b0feb6aa6 - requires_dist: - - click>=4.0 - - pytest>=3.6 ; extra == 'dev' - - pytest-cov ; extra == 'dev' - - wheel ; extra == 'dev' - - coveralls ; extra == 'dev' - pypi: https://files.pythonhosted.org/packages/73/86/43fa9f15c5b9fb6e82620428827cd3c284aa933431405d1bcf5231ae3d3e/cligj-0.7.2-py3-none-any.whl name: cligj version: 0.7.2 @@ -2517,10 +2509,10 @@ packages: - pkg:pypi/cycler?source=hash-mapping size: 14778 timestamp: 1764466758386 -- pypi: https://files.pythonhosted.org/packages/8d/05/8efadba80e1296526e69c1dceba8b0f0bc3756e8d69f6ed9b0e647cf3169/cyclopts-4.4.1-py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/20/5b/0eceb9a5990de9025733a0d212ca43649ba9facd58b8552b6bf93c11439d/cyclopts-4.4.4-py3-none-any.whl name: cyclopts - version: 4.4.1 - sha256: 67500e9fde90f335fddbf9c452d2e7c4f58209dffe52e7abb1e272796a963bde + version: 4.4.4 + sha256: 316f798fe2f2a30cb70e7140cfde2a46617bfbb575d31bbfdc0b2410a447bd83 requires_dist: - attrs>=23.1.0 - docstring-parser>=0.15,<4.0 @@ -2869,7 +2861,7 @@ packages: - pypi: ./ name: entropice version: 0.1.0 - sha256: cb0c27d2c23c64d7533c03e380cf55c40e82a4d52a0392a829fad06a4ca93736 + sha256: c15584d2588d1f67ff2c69d6d3afa461fd1b9571d423497221dcb295dd7b1514 requires_dist: - aiohttp>=3.12.11 - bokeh>=3.7.3 @@ -2932,6 +2924,7 @@ packages: - ty>=0.0.2,<0.0.3 - ruff>=0.14.9,<0.15 - pandas-stubs>=2.3.3.251201,<3 + - pytest>=9.0.2,<10 requires_python: '>=3.13,<3.14' - pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7 name: entropy @@ -3422,17 +3415,17 @@ packages: requires_dist: - smmap>=3.0.1,<6 requires_python: '>=3.7' -- pypi: https://files.pythonhosted.org/packages/01/61/d4b89fec821f72385526e1b9d9a3a0385dda4a72b206d28049e2c7cd39b8/gitpython-3.1.45-py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl name: gitpython - version: 3.1.45 - sha256: 8908cb2e02fb3b93b7eb0f2827125cb699869470432cc885f019b8fd0fccff77 + version: 3.1.46 + sha256: 79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058 requires_dist: - gitdb>=4.0.1,<5 - typing-extensions>=3.10.0.2 ; python_full_version < '3.10' - coverage[toml] ; extra == 'test' - ddt>=1.1.1,!=1.4.3 ; extra == 'test' - mock ; python_full_version < '3.8' and extra == 'test' - - mypy ; extra == 'test' + - mypy==1.18.2 ; python_full_version >= '3.9' and extra == 'test' - pre-commit ; extra == 'test' - pytest>=7.3.1 ; extra == 'test' - pytest-cov ; extra == 'test' @@ -3516,42 +3509,35 @@ packages: - google-api-core>=1.31.5,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0 - uritemplate>=3.0.1,<5 requires_python: '>=3.7' -- pypi: https://files.pythonhosted.org/packages/c6/97/451d55e05487a5cd6279a01a7e34921858b16f7dc8aa38a2c684743cd2b3/google_auth-2.45.0-py2.py3-none-any.whl +- pypi: 
https://files.pythonhosted.org/packages/db/18/79e9008530b79527e0d5f79e7eef08d3b179b7f851cfd3a2f27822fbdfa9/google_auth-2.47.0-py3-none-any.whl name: google-auth - version: 2.45.0 - sha256: 82344e86dc00410ef5382d99be677c6043d72e502b625aa4f4afa0bdacca0f36 + version: 2.47.0 + sha256: c516d68336bfde7cf0da26aab674a36fedcf04b37ac4edd59c597178760c3498 requires_dist: - - cachetools>=2.0.0,<7.0 - pyasn1-modules>=0.2.1 - rsa>=3.1.4,<5 - cryptography>=38.0.3 ; extra == 'cryptography' - - cryptography<39.0.0 ; python_full_version < '3.8' and extra == 'cryptography' - aiohttp>=3.6.2,<4.0.0 ; extra == 'aiohttp' - requests>=2.20.0,<3.0.0 ; extra == 'aiohttp' - cryptography ; extra == 'enterprise-cert' - pyopenssl ; extra == 'enterprise-cert' - pyopenssl>=20.0.0 ; extra == 'pyopenssl' - cryptography>=38.0.3 ; extra == 'pyopenssl' - - cryptography<39.0.0 ; python_full_version < '3.8' and extra == 'pyopenssl' - pyjwt>=2.0 ; extra == 'pyjwt' - cryptography>=38.0.3 ; extra == 'pyjwt' - - cryptography<39.0.0 ; python_full_version < '3.8' and extra == 'pyjwt' - pyu2f>=0.1.5 ; extra == 'reauth' - requests>=2.20.0,<3.0.0 ; extra == 'requests' - grpcio ; extra == 'testing' - flask ; extra == 'testing' - freezegun ; extra == 'testing' - - mock ; extra == 'testing' - oauth2client ; extra == 'testing' - pyjwt>=2.0 ; extra == 'testing' - cryptography>=38.0.3 ; extra == 'testing' - - cryptography<39.0.0 ; python_full_version < '3.8' and extra == 'testing' - pytest ; extra == 'testing' - pytest-cov ; extra == 'testing' - pytest-localserver ; extra == 'testing' - pyopenssl>=20.0.0 ; extra == 'testing' - cryptography>=38.0.3 ; extra == 'testing' - - cryptography<39.0.0 ; python_full_version < '3.8' and extra == 'testing' - pyu2f>=0.1.5 ; extra == 'testing' - responses ; extra == 'testing' - urllib3 ; extra == 'testing' @@ -3564,7 +3550,7 @@ packages: - aiohttp<3.10.0 ; extra == 'testing' - urllib3 ; extra == 'urllib3' - packaging ; extra == 'urllib3' - requires_python: '>=3.7' + requires_python: '>=3.8' - pypi: https://files.pythonhosted.org/packages/99/d5/3c97526c8796d3caf5f4b3bed2b05e8a7102326f00a334e7a438237f3b22/google_auth_httplib2-0.3.0-py3-none-any.whl name: google-auth-httplib2 version: 0.3.0 @@ -3767,10 +3753,10 @@ packages: - pkg:pypi/hyperframe?source=hash-mapping size: 17397 timestamp: 1737618427549 -- pypi: https://files.pythonhosted.org/packages/94/56/c5e8db63ba0e27b310a0b4c384da555b361741e7d186044d31f400c0419e/icechunk-1.1.14-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +- pypi: https://files.pythonhosted.org/packages/8c/d7/db466e07a21553441adbf915f0913a3f8fecece364cacb2392f11be267be/icechunk-1.1.15-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl name: icechunk - version: 1.1.14 - sha256: adb01a0275144c58f741b5402e658930326e86f7b389e879065e01625c021f7c + version: 1.1.15 + sha256: c9e0cc3c8623a48861470553dbb8b0f1e86600989f597ce41ecf47568d8d099d requires_dist: - zarr>=3,!=3.0.3 - boto3 ; extra == 'test' @@ -3926,6 +3912,11 @@ packages: - pkg:pypi/importlib-metadata?source=hash-mapping size: 34641 timestamp: 1747934053147 +- pypi: https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl + name: iniconfig + version: 2.3.0 + sha256: f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12 + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/4c/0f/b66d63d4a5426c09005d3713b056e634e00e69788fdc88d1ffe40e5b7654/ipycytoscape-1.3.3-py2.py3-none-any.whl name: ipycytoscape version: 
1.3.3 @@ -4024,10 +4015,10 @@ packages: - traittypes>=0.2.1,<3 - xyzservices>=2021.8.1 requires_python: '>=3.8' -- pypi: https://files.pythonhosted.org/packages/f1/df/8ee1c5dd1e3308b5d5b2f2dfea323bb2f3827da8d654abb6642051199049/ipython-9.8.0-py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/86/92/162cfaee4ccf370465c5af1ce36a9eacec1becb552f2033bb3584e6f640a/ipython-9.9.0-py3-none-any.whl name: ipython - version: 9.8.0 - sha256: ebe6d1d58d7d988fbf23ff8ff6d8e1622cfdb194daf4b7b73b792c4ec3b85385 + version: 9.9.0 + sha256: b457fe9165df2b84e8ec909a97abcf2ed88f565970efba16b1f7229c283d252b requires_dist: - colorama>=0.4.4 ; sys_platform == 'win32' - decorator>=4.3.2 @@ -4067,7 +4058,8 @@ packages: - pandas>2.1 ; extra == 'test-extra' - trio>=0.1.0 ; extra == 'test-extra' - matplotlib>3.9 ; extra == 'matplotlib' - - ipython[doc,matplotlib,test,test-extra] ; extra == 'all' + - ipython[doc,matplotlib,terminal,test,test-extra] ; extra == 'all' + - argcomplete>=3.0 ; extra == 'all' requires_python: '>=3.11' - pypi: https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl name: ipython-pygments-lexers @@ -6859,14 +6851,15 @@ packages: version: 1.6.0 sha256: 87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c requires_python: '>=3.5' -- pypi: https://files.pythonhosted.org/packages/97/1a/78b19893197ed7525edfa7f124a461626541e82aec694a468ba97755c24e/netcdf4-1.7.3-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl +- pypi: https://files.pythonhosted.org/packages/7b/7a/a8d32501bb95ecff342004a674720164f95ad616f269450b3bc13dc88ae3/netcdf4-1.7.4-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl name: netcdf4 - version: 1.7.3 - sha256: 0c764ba6f6a1421cab5496097e8a1c4d2e36be2a04880dfd288bb61b348c217e + version: 1.7.4 + sha256: a72c9f58767779ec14cb7451c3b56bdd8fdc027a792fac2062b14e090c5617f3 requires_dist: - cftime - certifi - - numpy + - numpy>=2.3.0 ; platform_machine == 'ARM64' and sys_platform == 'win32' + - numpy>=1.21.2 ; platform_machine != 'ARM64' or sys_platform != 'win32' - cython ; extra == 'tests' - packaging ; extra == 'tests' - pytest ; extra == 'tests' @@ -7045,10 +7038,10 @@ packages: - pkg:pypi/nvidia-ml-py?source=hash-mapping size: 48971 timestamp: 1765209768013 -- pypi: https://files.pythonhosted.org/packages/4a/4e/44dbb46b3d1b0ec61afda8e84837870f2f9ace33c564317d59b70bc19d3e/nvidia_nccl_cu12-2.28.9-py3-none-manylinux_2_18_x86_64.whl +- pypi: https://files.pythonhosted.org/packages/23/2d/609d0392d992259c6dc39881688a7fc13b1397a668bc360fbd68d1396f85/nvidia_nccl_cu12-2.29.2-py3-none-manylinux_2_18_x86_64.whl name: nvidia-nccl-cu12 - version: 2.28.9 - sha256: 485776daa8447da5da39681af455aa3b2c2586ddcf4af8772495e7c532c7e5ab + version: 2.29.2 + sha256: 3a9a0bf4142126e0d0ed99ec202579bef8d007601f9fab75af60b10324666b12 requires_python: '>=3' - conda: https://conda.anaconda.org/conda-forge/linux-64/nvtx-0.2.14-py313h07c4f96_0.conda sha256: 9341cb332428242ab938c5fc202008c12430ec43b8b83511d327f14bf8fd6d96 @@ -7501,6 +7494,17 @@ packages: - xarray ; extra == 'dev-optional' - plotly[dev-optional] ; extra == 'dev' requires_python: '>=3.8' +- pypi: https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl + name: pluggy + version: 1.6.0 + sha256: e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746 + requires_dist: + - pre-commit ; extra == 'dev' + - tox ; extra == 'dev' + - pytest ; 
extra == 'testing' + - pytest-benchmark ; extra == 'testing' + - coverage ; extra == 'testing' + requires_python: '>=3.9' - conda: https://conda.anaconda.org/conda-forge/noarch/polars-1.34.0-pyh6a1acc5_0.conda sha256: 7e8bb10f4373202a0be760d9ac74f92c5e7e6095251180642678a8f57f10c58a md5: d398dbcb3312bbebc2b2f3dbb98b4262 @@ -7652,10 +7656,10 @@ packages: - pkg:pypi/psutil?source=hash-mapping size: 501735 timestamp: 1762092897061 -- pypi: https://files.pythonhosted.org/packages/ff/7b/e9a6fa461ef266c5a23485004934b8f08a2a8ddc447802161ea56d9837dd/psygnal-0.15.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl +- pypi: https://files.pythonhosted.org/packages/2d/4f/3593e5adb88a188c798604aed95fbc1479f30230e7f51e8f2c770e6a3832/psygnal-0.15.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl name: psygnal - version: 0.15.0 - sha256: a0172efeb861280bca05673989a4df21624f44344eff20b873d8c9d0edc01350 + version: 0.15.1 + sha256: e9fca977f5335deea39aed22e31d9795983e4f243e59a7d3c4105793adb7693d requires_dist: - wrapt ; extra == 'proxy' - pydantic ; extra == 'pydantic' @@ -8013,6 +8017,26 @@ packages: - pystac[validation]>=1.10.0 - python-dateutil>=2.8.2 requires_python: '>=3.10' +- pypi: https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl + name: pytest + version: 9.0.2 + sha256: 711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b + requires_dist: + - colorama>=0.4 ; sys_platform == 'win32' + - exceptiongroup>=1 ; python_full_version < '3.11' + - iniconfig>=1.0.1 + - packaging>=22 + - pluggy>=1.5,<2 + - pygments>=2.7.2 + - tomli>=1 ; python_full_version < '3.11' + - argcomplete ; extra == 'dev' + - attrs>=19.2 ; extra == 'dev' + - hypothesis>=3.56 ; extra == 'dev' + - mock ; extra == 'dev' + - requests ; extra == 'dev' + - setuptools ; extra == 'dev' + - xmlschema ; extra == 'dev' + requires_python: '>=3.10' - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.13.11-hc97d973_100_cp313.conda build_number: 100 sha256: 9cf014cf28e93ee242bacfbf664e8b45ae06e50b04291e640abeaeb0cba0364c @@ -8383,19 +8407,19 @@ packages: license: LicenseRef-Custom size: 6143 timestamp: 1765438804958 -- pypi: https://files.pythonhosted.org/packages/7b/84/66c0d9cca2a09074ec2ce6fffa87709ca51b0d197ae742d835e841bac660/rasterio-1.4.4-cp313-cp313-manylinux_2_28_x86_64.whl +- pypi: https://files.pythonhosted.org/packages/48/4a/1af9aa9810fb30668568f2c4dd3eec2412c8e9762b69201d971c509b295e/rasterio-1.5.0-cp313-cp313-manylinux_2_28_x86_64.whl name: rasterio - version: 1.4.4 - sha256: c072450caa96428b1218b030500bb908fd6f09bc013a88969ff81a124b6a112a + version: 1.5.0 + sha256: 08a7580cbb9b3bd320bdf827e10c9b2424d0df066d8eef6f2feb37e154ce0c17 requires_dist: - affine - attrs - certifi - click>=4.0,!=8.2.* - cligj>=0.5 - - numpy>=1.24 - - click-plugins + - numpy>=2 - pyparsing + - rasterio[docs,ipython,plot,s3,test] ; extra == 'all' - ghp-import ; extra == 'docs' - numpydoc ; extra == 'docs' - sphinx ; extra == 'docs' @@ -8404,28 +8428,17 @@ packages: - ipython>=2.0 ; extra == 'ipython' - matplotlib ; extra == 'plot' - boto3>=1.2.4 ; extra == 's3' + - aiohttp ; extra == 'test' - boto3>=1.2.4 ; extra == 'test' - fsspec ; extra == 'test' - hypothesis ; extra == 'test' + - matplotlib ; extra == 'test' - packaging ; extra == 'test' - pytest-cov>=2.2.0 ; extra == 'test' - pytest>=2.8.2 ; extra == 'test' + - requests ; extra == 'test' - shapely ; extra == 'test' - - fsspec ; extra == 
'all' - - sphinx-rtd-theme ; extra == 'all' - - ipython>=2.0 ; extra == 'all' - - packaging ; extra == 'all' - - ghp-import ; extra == 'all' - - boto3>=1.2.4 ; extra == 'all' - - matplotlib ; extra == 'all' - - sphinx-click ; extra == 'all' - - sphinx ; extra == 'all' - - pytest>=2.8.2 ; extra == 'all' - - hypothesis ; extra == 'all' - - shapely ; extra == 'all' - - numpydoc ; extra == 'all' - - pytest-cov>=2.2.0 ; extra == 'all' - requires_python: '>=3.10' + requires_python: '>=3.12' - pypi: https://files.pythonhosted.org/packages/f2/98/7e6d147fd16a10a5f821db6e25f192265d6ecca3d82957a4fdd592cad49c/ratelim-0.1.6-py2.py3-none-any.whl name: ratelim version: 0.1.6 @@ -9139,10 +9152,10 @@ packages: - pkg:pypi/terminado?source=hash-mapping size: 22452 timestamp: 1710262728753 -- pypi: https://files.pythonhosted.org/packages/b5/fc/5e2988590ff2e0128eea6446806c904445a44e17256c67141573ea16b5a5/textual-6.11.0-py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/84/38/47fab2a5fad163ca4851f7a20eb2442491cc63bf2756ec4ef161bc1461dd/textual-7.0.1-py3-none-any.whl name: textual - version: 6.11.0 - sha256: 9e663b73ed37123a9b13c16a0c85e09ef917a4cfded97814361ed5cccfa40f89 + version: 7.0.1 + sha256: f9b7d16fa9b640bfff2a2008bf31e3f2d4429dc85e07a9583be033840ed15174 requires_dist: - markdown-it-py[linkify]>=2.1.0 - mdit-py-plugins @@ -9429,10 +9442,10 @@ packages: license: BSD-3-Clause size: 508347 timestamp: 1765407086135 -- pypi: https://files.pythonhosted.org/packages/95/20/92e3083b0e854943015bc8a7866e284ead9efadf9bf6809e6fce3b7ded61/ultraplot-1.66.0-py3-none-any.whl +- pypi: https://files.pythonhosted.org/packages/43/6c/b26831b890b37c09882f6406efd31441c8e512bf1efbc967b9d867c5e02b/ultraplot-1.70.0-py3-none-any.whl name: ultraplot - version: 1.66.0 - sha256: 87fecb897ca5c7d54b76ac81e5b8635be45d9c9d42d629469f1d283e6405f9e1 + version: 1.70.0 + sha256: 2b29d1b1e36bd6cf88458370825cfab2c62b9acab706a2cfa434660d7dc4bf74 requires_dist: - numpy>=1.26.0 - matplotlib>=3.9,<3.11 @@ -9496,10 +9509,10 @@ packages: - packaging - narwhals>=1.42 requires_python: '>=3.9' -- pypi: https://files.pythonhosted.org/packages/a7/6b/48f6d47a92eaf6f0dd235146307a7eb0d179b78d2faebc53aca3f1e49177/vl_convert_python-1.8.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +- pypi: https://files.pythonhosted.org/packages/6f/61/dc6f4a38cf1b8699f64c57d7f021ca42c39bfe782d8a6eaefb7e8418e925/vl_convert_python-1.9.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl name: vl-convert-python - version: 1.8.0 - sha256: b51264998e8fcc43dbce801484a950cfe6513cdc4c46b20604ef50989855a617 + version: 1.9.0 + sha256: 849e6773a7e05d58ab215386b1065e7713f4846b9ac6b0d743bb3e1b20337231 requires_python: '>=3.7' - pypi: https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl name: watchdog diff --git a/pyproject.toml b/pyproject.toml index 6f51207..bae64a3 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,7 @@ dependencies = [ "pypalettes>=0.2.1,<0.3", "ty>=0.0.2,<0.0.3", "ruff>=0.14.9,<0.15", - "pandas-stubs>=2.3.3.251201,<3", + "pandas-stubs>=2.3.3.251201,<3", "pytest>=9.0.2,<10", ] [project.scripts] diff --git a/scripts/recalculate_test_metrics.py b/scripts/recalculate_test_metrics.py new file mode 100644 index 0000000..882a68c --- /dev/null +++ b/scripts/recalculate_test_metrics.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python +"""Recalculate test metrics and confusion matrix for existing training results. 
+
+This script loads previously trained models and recalculates test metrics
+and confusion matrices for training runs that were completed before these
+outputs were added to the training pipeline.
+"""
+
+import pickle
+from pathlib import Path
+
+import cupy as cp
+import numpy as np
+import toml
+import torch
+import xarray as xr
+from sklearn import set_config
+from sklearn.metrics import confusion_matrix
+
+from entropice.ml.dataset import DatasetEnsemble
+from entropice.utils.paths import RESULTS_DIR
+
+# Enable array_api_dispatch to handle CuPy/NumPy namespace properly
+set_config(array_api_dispatch=True)
+
+
+def recalculate_metrics(results_dir: Path):
+    """Recalculate test metrics and confusion matrix for a training result.
+
+    Args:
+        results_dir: Path to the results directory containing the trained model.
+
+    """
+    print(f"\nProcessing: {results_dir}")
+
+    # Load the search settings to get training configuration
+    settings_file = results_dir / "search_settings.toml"
+    if not settings_file.exists():
+        print("  ❌ Missing search_settings.toml, skipping...")
+        return
+
+    with open(settings_file) as f:
+        config = toml.load(f)
+    settings = config["settings"]
+
+    # Output paths for the recalculated metrics
+    test_metrics_file = results_dir / "test_metrics.toml"
+    cm_file = results_dir / "confusion_matrix.nc"
+
+    # Load the best estimator
+    best_model_file = results_dir / "best_estimator_model.pkl"
+    if not best_model_file.exists():
+        print("  ❌ Missing best_estimator_model.pkl, skipping...")
+        return
+
+    print(f"  Loading best estimator from {best_model_file.name}...")
+    with open(best_model_file, "rb") as f:
+        best_estimator = pickle.load(f)
+
+    # Recreate the dataset ensemble
+    print("  Recreating training dataset...")
+    dataset_ensemble = DatasetEnsemble(
+        grid=settings["grid"],
+        level=settings["level"],
+        target=settings["target"],
+        members=settings.get(
+            "members",
+            [
+                "AlphaEarth",
+                "ArcticDEM",
+                "ERA5-yearly",
+                "ERA5-seasonal",
+                "ERA5-shoulder",
+            ],
+        ),
+        dimension_filters=settings.get("dimension_filters", {}),
+        variable_filters=settings.get("variable_filters", {}),
+        filter_target=settings.get("filter_target", False),
+        add_lonlat=settings.get("add_lonlat", True),
+    )
+
+    task = settings["task"]
+    model = settings["model"]
+    device = "torch" if model in ["espa"] else "cuda"
+
+    # Create training data
+    training_data = dataset_ensemble.create_cat_training_dataset(task=task, device=device)
+
+    # Prepare test data - match training.py's approach
+    print("  Preparing test data...")
+    # For XGBoost with CuPy arrays, convert y_test to CPU (same as training.py)
+    y_test = (
+        training_data.y.test.get()
+        if model == "xgboost" and isinstance(training_data.y.test, cp.ndarray)
+        else training_data.y.test
+    )
+
+    # Compute predictions on the test set (use original device data)
+    print("  Computing predictions on test set...")
+    y_pred = best_estimator.predict(training_data.X.test)
+
+    # Move predictions and targets onto the GPU as torch tensors so that
+    # sklearn's array-API dispatch computes the metrics on CUDA
+    y_pred = torch.as_tensor(y_pred, device="cuda")
+    y_test = torch.as_tensor(y_test, device="cuda")
+
+    # Compute metrics manually to avoid device issues
+    print("  Computing test metrics...")
+    from sklearn.metrics import (
+        accuracy_score,
+        f1_score,
+        jaccard_score,
+        precision_score,
+        recall_score,
+    )
+
+    test_metrics = {}
+    if task == "binary":
+        test_metrics["accuracy"] = float(accuracy_score(y_test, y_pred))
+        test_metrics["recall"] = float(recall_score(y_test, y_pred))
+        test_metrics["precision"] = float(precision_score(y_test, y_pred))
+        test_metrics["f1"] = float(f1_score(y_test, y_pred))
+        test_metrics["jaccard"] = float(jaccard_score(y_test, y_pred))
+    else:
+        test_metrics["accuracy"] = float(accuracy_score(y_test, y_pred))
+        test_metrics["f1_macro"] = float(f1_score(y_test, y_pred, average="macro"))
+        test_metrics["f1_weighted"] = float(f1_score(y_test, y_pred, average="weighted"))
+        test_metrics["precision_macro"] = float(precision_score(y_test, y_pred, average="macro", zero_division=0))
+        test_metrics["precision_weighted"] = float(precision_score(y_test, y_pred, average="weighted", zero_division=0))
+        test_metrics["recall_macro"] = float(recall_score(y_test, y_pred, average="macro"))
+        test_metrics["jaccard_micro"] = float(jaccard_score(y_test, y_pred, average="micro"))
+        test_metrics["jaccard_macro"] = float(jaccard_score(y_test, y_pred, average="macro"))
+        test_metrics["jaccard_weighted"] = float(jaccard_score(y_test, y_pred, average="weighted"))
+
+    # Get confusion matrix
+    print("  Computing confusion matrix...")
+    # Class indices as a CUDA torch tensor so confusion_matrix dispatches on torch
+    labels = torch.as_tensor(np.arange(len(training_data.y.labels)), device="cuda")
+    print("  Device of y_test:", getattr(y_test, "device", "cpu"))
+    print("  Device of y_pred:", getattr(y_pred, "device", "cpu"))
+    print("  Device of labels:", getattr(labels, "device", "cpu"))
+    cm = confusion_matrix(y_test, y_pred, labels=labels)
+    cm = cm.cpu().numpy()
+    label_names = list(training_data.y.labels)
+    cm_xr = xr.DataArray(
+        cm,
+        dims=["true_label", "predicted_label"],
+        coords={"true_label": label_names, "predicted_label": label_names},
+        name="confusion_matrix",
+    )
+
+    # Store the test metrics
+    if not test_metrics_file.exists():
+        print(f"  Storing test metrics to {test_metrics_file.name}...")
+        with open(test_metrics_file, "w") as f:
+            toml.dump({"test_metrics": test_metrics}, f)
+    else:
+        print("  ✓ Test metrics already exist")
+
+    # Store the confusion matrix (always rewritten so stale matrices are replaced)
+    print(f"  Storing confusion matrix to {cm_file.name}...")
+    cm_xr.to_netcdf(cm_file, engine="h5netcdf")
+
+    print("  ✓ Done!")
+
+
+def main():
+    """Find all training results and recalculate metrics for those missing them."""
+    print("Searching for training results directories...")
+
+    # Find all results directories
+    results_dirs = sorted([d for d in RESULTS_DIR.glob("*") if d.is_dir()])
+
+    print(f"Found {len(results_dirs)} results directories.\n")
+
+    for results_dir in results_dirs:
+        try:
+            recalculate_metrics(results_dir)
+        except Exception as e:
+            print(f"  ❌ Error processing {results_dir.name}: {e}")
+            continue
+
+    print("\n✅ All done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/rechunk_zarr.py b/scripts/rechunk_zarr.py
new file mode 100644
index 0000000..664ccc9
--- /dev/null
+++ b/scripts/rechunk_zarr.py
@@ -0,0 +1,62 @@
+"""Rechunk the daily ERA5 Zarr store into larger time chunks."""
+
+import dask.distributed as dd
+import xarray as xr
+from rich import print
+
+import entropice.utils.codecs
+from entropice.utils.paths import get_era5_stores
+
+
+def print_info(daily_raw=None, show_vars: bool = True):
+    """Print dimension, chunk, and encoding information for the daily store."""
+    if daily_raw is None:
+        daily_store = get_era5_stores("daily")
+        daily_raw = xr.open_zarr(daily_store, consolidated=False)
+    print("=== Daily INFO ===")
+    print(f"  Dims: {daily_raw.sizes}")
+    numchunks = 1
+    chunksizes = {}
+    approxchunksize = 4  # bytes per element (float32)
+    for d, cs in daily_raw.chunksizes.items():
+        numchunks *= len(cs)
+        chunksizes[d] = max(cs)
+        approxchunksize *= max(cs)
+    approxchunksize /= 1e6  # bytes -> MB
+    print(f"  Chunks: {chunksizes} (~{approxchunksize:.2f}MB) => {numchunks} total")
+    print(f"  Encoding: {daily_raw.encoding}")
+    if show_vars:
+        print("  Variables:")
+        for var in daily_raw.data_vars:
+            da = daily_raw[var]
+            print(f"    {var} Encoding:")
+            print(da.encoding)
+    print("")
+
+
+def rechunk():
+    """Rewrite the daily store with time chunks of 120 and whole-axis spatial chunks."""
+    daily_store = get_era5_stores("daily")
+    daily_raw = xr.open_zarr(daily_store, consolidated=False)
+    print_info(daily_raw, False)
+    daily_raw = daily_raw.chunk({
+        "time": 120,
+        "latitude": -1,  # -1 keeps the full axis (337)
+        "longitude": -1,  # -1 keeps the full axis (3600)
+    })
+    print_info(daily_raw, False)
+
+    encoding = entropice.utils.codecs.from_ds(daily_raw)
+    daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked")
+    daily_raw.to_zarr(daily_store_rechunked, mode="w", encoding=encoding, consolidated=False)
+
+
+if __name__ == "__main__":
+    with (
+        dd.LocalCluster(n_workers=1, threads_per_worker=10, memory_limit="100GB") as cluster,
+        dd.Client(cluster) as client,
+    ):
+        print(client)
+        print(client.dashboard_link)
+        rechunk()
+        print("Done.")
diff --git a/scripts/rerun_missing_inference.py b/scripts/rerun_missing_inference.py
new file mode 100644
index 0000000..cd463c1
--- /dev/null
+++ b/scripts/rerun_missing_inference.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python
+"""Rerun inference for training results that are missing predicted probabilities.
+
+This script searches through training result directories and identifies those that have
+a trained model but are missing inference results. It then loads the model and dataset
+configuration, reruns inference, and saves the results.
+"""
+
+import pickle
+from pathlib import Path
+
+import toml
+from rich.console import Console
+from rich.progress import track
+
+from entropice.ml.dataset import DatasetEnsemble
+from entropice.ml.inference import predict_proba
+from entropice.utils.paths import RESULTS_DIR
+
+console = Console()
+
+
+def find_incomplete_trainings() -> list[Path]:
+    """Find training result directories missing inference results.
+
+    Returns:
+        list[Path]: List of directories with trained models but missing predictions.
+
+    """
+    incomplete = []
+
+    if not RESULTS_DIR.exists():
+        console.print(f"[yellow]Results directory not found: {RESULTS_DIR}[/yellow]")
+        return incomplete
+
+    # Search for all training result directories
+    for result_dir in RESULTS_DIR.glob("*_cv*"):
+        if not result_dir.is_dir():
+            continue
+
+        model_file = result_dir / "best_estimator_model.pkl"
+        settings_file = result_dir / "search_settings.toml"
+        predictions_file = result_dir / "predicted_probabilities.parquet"
+
+        # Check if model and settings exist but predictions are missing
+        if model_file.exists() and settings_file.exists() and not predictions_file.exists():
+            incomplete.append(result_dir)
+
+    return incomplete
+
+
+def rerun_inference(result_dir: Path) -> bool:
+    """Rerun inference for a training result directory.
+
+    Args:
+        result_dir (Path): Path to the training result directory.
+
+    Returns:
+        bool: True if successful, False otherwise.
+ + """ + try: + console.print(f"\n[cyan]Processing: {result_dir.name}[/cyan]") + + # Load settings + settings_file = result_dir / "search_settings.toml" + with open(settings_file) as f: + settings_data = toml.load(f) + + settings = settings_data["settings"] + + # Reconstruct DatasetEnsemble from settings + ensemble = DatasetEnsemble( + grid=settings["grid"], + level=settings["level"], + target=settings["target"], + members=settings["members"], + dimension_filters=settings.get("dimension_filters", {}), + variable_filters=settings.get("variable_filters", {}), + filter_target=settings.get("filter_target", False), + add_lonlat=settings.get("add_lonlat", True), + ) + + # Load trained model + model_file = result_dir / "best_estimator_model.pkl" + with open(model_file, "rb") as f: + clf = pickle.load(f) + + console.print("[green]✓[/green] Loaded model and settings") + + # Get class labels + classes = settings["classes"] + + # Run inference + console.print("[yellow]Running inference...[/yellow]") + preds = predict_proba(ensemble, clf=clf, classes=classes) + + # Save predictions + preds_file = result_dir / "predicted_probabilities.parquet" + preds.to_parquet(preds_file) + + console.print(f"[green]✓[/green] Saved {len(preds)} predictions to {preds_file.name}") + return True + + except Exception as e: + console.print(f"[red]✗ Error processing {result_dir.name}: {e}[/red]") + import traceback + + console.print(f"[red]{traceback.format_exc()}[/red]") + return False + + +def main(): + """Rerun missing inferences for incomplete training results.""" + console.print("[bold blue]Searching for incomplete training results...[/bold blue]") + + incomplete_dirs = find_incomplete_trainings() + + if not incomplete_dirs: + console.print("[green]No incomplete trainings found. 
All trainings have predictions![/green]") + return + + console.print(f"[yellow]Found {len(incomplete_dirs)} training(s) missing predictions:[/yellow]") + for d in incomplete_dirs: + console.print(f" • {d.name}") + + console.print(f"\n[bold]Processing {len(incomplete_dirs)} training result(s)...[/bold]\n") + + successful = 0 + failed = 0 + + for result_dir in track(incomplete_dirs, description="Rerunning inference"): + if rerun_inference(result_dir): + successful += 1 + else: + failed += 1 + + console.print("\n[bold]Summary:[/bold]") + console.print(f" [green]Successful: {successful}[/green]") + console.print(f" [red]Failed: {failed}[/red]") + + +if __name__ == "__main__": + main() diff --git a/src/entropice/dashboard/plots/hyperparameter_analysis.py b/src/entropice/dashboard/plots/hyperparameter_analysis.py index 4c9100e..76ba038 100644 --- a/src/entropice/dashboard/plots/hyperparameter_analysis.py +++ b/src/entropice/dashboard/plots/hyperparameter_analysis.py @@ -10,9 +10,11 @@ import pandas as pd import pydeck as pdk import streamlit as st +from entropice.dashboard.utils.class_ordering import get_ordered_classes from entropice.dashboard.utils.colors import get_cmap, get_palette from entropice.dashboard.utils.geometry import fix_hex_geometry from entropice.ml.dataset import DatasetEnsemble +from entropice.ml.training import TrainingSettings def render_performance_summary(results: pd.DataFrame, refit_metric: str): @@ -125,7 +127,7 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str): ) -def render_parameter_distributions(results: pd.DataFrame, settings: dict | None = None): +def render_parameter_distributions(results: pd.DataFrame, settings: TrainingSettings | None = None): """Render histograms of parameter distributions explored. Args: @@ -1152,15 +1154,18 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1 @st.fragment -def render_confusion_matrix_map(result_path: Path, settings: dict): - """Render 3D pydeck map showing confusion matrix results (TP, FP, TN, FN). +def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings): + """Render 3D pydeck map showing prediction results. + + Uses true labels for elevation (height) and different shades of red for incorrect predictions + based on the predicted class. Args: result_path: Path to the training result directory. settings: Settings dictionary containing grid, level, task, and target information. 
""" - st.subheader("🗺️ Confusion Matrix Spatial Distribution") + st.subheader("🗺️ Prediction Results Map") # Load predicted probabilities preds_file = result_path / "predicted_probabilities.parquet" @@ -1190,62 +1195,41 @@ def render_confusion_matrix_map(result_path: Path, settings: dict): st.error(f"Error loading training data: {e}") return - # Get the labeled cells (those with true labels) - labeled_cells = training_data.dataset[training_data.dataset.index.isin(training_data.y.binned.index)] + # Get all cells from the complete dataset (not just test split) + # Use the full dataset which includes both train and test splits + all_cells = training_data.dataset.copy() # Merge predictions with true labels # Reset index to avoid ambiguity between index and column - labeled_gdf = labeled_cells.copy() - labeled_gdf = labeled_gdf.reset_index().rename(columns={"index": "cell_id"}) - labeled_gdf["true_class"] = training_data.y.binned.loc[labeled_cells.index].to_numpy() + labeled_gdf = all_cells.reset_index().rename(columns={"index": "cell_id"}) + labeled_gdf["true_class"] = training_data.y.binned.loc[all_cells.index].to_numpy() - # Merge with predictions - ensure we keep GeoDataFrame type - merged_df = labeled_gdf.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="inner") + # Merge with predictions - use left join to keep all cells + merged_df = labeled_gdf.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="left") merged = gpd.GeoDataFrame(merged_df, geometry="geometry", crs=labeled_gdf.crs) + # Mark which cells have predictions (test split) vs not (training split) + merged["in_test_split"] = merged["predicted_class"].notna() + + # For cells without predictions (training split), use true class as predicted class for visualization + merged["predicted_class"] = merged["predicted_class"].fillna(merged["true_class"]) + if len(merged) == 0: st.warning("No matching predictions found for labeled cells.") return - # Determine confusion matrix category - def get_confusion_category(row): - true_label = row["true_class"] - pred_label = row["predicted_class"] + # Mark correct vs incorrect predictions (only meaningful for test split) + merged["is_correct"] = merged["true_class"] == merged["predicted_class"] - if task == "binary": - # For binary classification - if true_label == "RTS" and pred_label == "RTS": - return "True Positive" - elif true_label == "RTS" and pred_label == "No-RTS": - return "False Negative" - elif true_label == "No-RTS" and pred_label == "RTS": - return "False Positive" - else: # true_label == "No-RTS" and pred_label == "No-RTS" - return "True Negative" - else: - # For multiclass (count/density) - if true_label == pred_label: - return "Correct" - else: - return "Incorrect" - - merged["confusion_category"] = merged.apply(get_confusion_category, axis=1) + # Get ordered class labels for the task + ordered_classes = get_ordered_classes(task) # Create controls - col1, col2 = st.columns([3, 1]) + col1, col2, col3 = st.columns([2, 1, 1]) with col1: - # Filter by confusion category - if task == "binary": - categories = [ - "All", - "True Positive", - "False Positive", - "True Negative", - "False Negative", - ] - else: - categories = ["All", "Correct", "Incorrect"] + # Filter by prediction correctness and split + categories = ["All", "Test Split Only", "Training Split Only", "Correct (Test)", "Incorrect (Test)"] selected_category = st.selectbox( "Filter by Category", @@ -1263,10 +1247,26 @@ def render_confusion_matrix_map(result_path: Path, settings: dict): 
key="confusion_map_opacity", ) + with col3: + line_width = st.slider( + "Line Width", + min_value=0.5, + max_value=3.0, + value=1.0, + step=0.5, + key="confusion_map_line_width", + ) + # Filter data if needed - if selected_category != "All": - display_gdf = merged[merged["confusion_category"] == selected_category].copy() - else: + if selected_category == "Test Split Only": + display_gdf = merged[merged["in_test_split"]].copy() + elif selected_category == "Training Split Only": + display_gdf = merged[~merged["in_test_split"]].copy() + elif selected_category == "Correct (Test)": + display_gdf = merged[merged["is_correct"] & merged["in_test_split"]].copy() + elif selected_category == "Incorrect (Test)": + display_gdf = merged[~merged["is_correct"] & merged["in_test_split"]].copy() + else: # "All" display_gdf = merged.copy() if len(display_gdf) == 0: @@ -1280,49 +1280,72 @@ def render_confusion_matrix_map(result_path: Path, settings: dict): if grid == "hex": display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(fix_hex_geometry) - # Assign colors based on confusion category - if task == "binary": - color_map = { - "True Positive": [46, 204, 113], # Green - "False Positive": [231, 76, 60], # Red - "True Negative": [52, 152, 219], # Blue - "False Negative": [241, 196, 15], # Yellow - } - else: - color_map = { - "Correct": [46, 204, 113], # Green - "Incorrect": [231, 76, 60], # Red - } + # Get red material colormap for incorrect predictions + red_cmap = get_cmap("red_predictions") # Use red_material palette + n_classes = len(ordered_classes) - display_gdf_wgs84["fill_color"] = display_gdf_wgs84["confusion_category"].map(color_map) + # Assign colors based on correctness + def get_color(row): + if row["is_correct"]: + # Green for correct predictions + return [46, 204, 113] + else: + # Different shades of red for each predicted class (ordered) + pred_class = row["predicted_class"] + if pred_class in ordered_classes: + class_idx = ordered_classes.index(pred_class) + # Sample from red colormap based on class index + color_value = red_cmap(class_idx / max(n_classes - 1, 1)) + return [int(color_value[0] * 255), int(color_value[1] * 255), int(color_value[2] * 255)] + else: + # Fallback red if class not found + return [231, 76, 60] - # Add elevation based on confusion category (higher for errors) - if task == "binary": - elevation_map = { - "True Positive": 0.8, - "False Positive": 1.0, - "True Negative": 0.3, - "False Negative": 1.0, - } - else: - elevation_map = { - "Correct": 0.5, - "Incorrect": 1.0, - } + display_gdf_wgs84["fill_color"] = display_gdf_wgs84.apply(get_color, axis=1) - display_gdf_wgs84["elevation"] = display_gdf_wgs84["confusion_category"].map(elevation_map) + # Add line color based on split: blue for test split, orange for training split + def get_line_color(row): + if row["in_test_split"]: + return [52, 152, 219] # Blue for test split + else: + return [230, 126, 34] # Orange for training split + + display_gdf_wgs84["line_color"] = display_gdf_wgs84.apply(get_line_color, axis=1) + + # Add elevation based on TRUE label (not predicted) + # Map each true class to a height based on its position in the ordered list + def get_elevation(row): + true_class = row["true_class"] + if true_class in ordered_classes: + class_idx = ordered_classes.index(true_class) + # Normalize to 0-1 range based on class position + return (class_idx + 1) / n_classes + else: + return 0.5 # Default elevation + + display_gdf_wgs84["elevation"] = display_gdf_wgs84.apply(get_elevation, axis=1) # Convert to 
GeoJSON format geojson_data = [] for _, row in display_gdf_wgs84.iterrows(): + # Determine split and status for tooltip + split_name = "Test Split" if row["in_test_split"] else "Training Split" + if row["in_test_split"]: + status = "✓ Correct" if row["is_correct"] else "✗ Incorrect" + else: + status = "(No prediction - training data)" + feature = { "type": "Feature", "geometry": row["geometry"].__geo_interface__, "properties": { "true_class": str(row["true_class"]), - "predicted_class": str(row["predicted_class"]), - "confusion_category": str(row["confusion_category"]), + "predicted_class": str(row["predicted_class"]) if row["in_test_split"] else "N/A", + "is_correct": bool(row["is_correct"]), + "split": split_name, + "status": status, "fill_color": row["fill_color"], + "line_color": row["line_color"], "elevation": float(row["elevation"]), }, } @@ -1338,8 +1361,8 @@ def render_confusion_matrix_map(result_path: Path, settings: dict): extruded=True, wireframe=False, get_fill_color="properties.fill_color", - get_line_color=[80, 80, 80], - line_width_min_pixels=0.5, + get_line_color="properties.line_color", + line_width_min_pixels=line_width, get_elevation="properties.elevation", elevation_scale=500000, pickable=True, @@ -1353,9 +1376,10 @@ def render_confusion_matrix_map(result_path: Path, settings: dict): layers=[layer], initial_view_state=view_state, tooltip={ - "html": "True Label: {true_class}
" + "html": "Split: {split}
" + "True Label: {true_class}
" "Predicted Label: {predicted_class}
" - "Category: {confusion_category}", + "Status: {status}", "style": {"backgroundColor": "steelblue", "color": "white"}, }, map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json", @@ -1365,56 +1389,240 @@ def render_confusion_matrix_map(result_path: Path, settings: dict): st.pydeck_chart(deck) # Show statistics - col1, col2, col3 = st.columns(3) + col1, col2, col3, col4 = st.columns(4) with col1: st.metric("Total Labeled Cells", len(merged)) - if task == "binary": - with col2: - tp = len(merged[merged["confusion_category"] == "True Positive"]) - fp = len(merged[merged["confusion_category"] == "False Positive"]) - tn = len(merged[merged["confusion_category"] == "True Negative"]) - fn = len(merged[merged["confusion_category"] == "False Negative"]) + with col2: + test_count = len(merged[merged["in_test_split"]]) + st.metric("Test Split", test_count) - accuracy = (tp + tn) / len(merged) if len(merged) > 0 else 0 - st.metric("Accuracy", f"{accuracy:.2%}") + with col3: + train_count = len(merged[~merged["in_test_split"]]) + st.metric("Training Split", train_count) - with col3: - precision = tp / (tp + fp) if (tp + fp) > 0 else 0 - recall = tp / (tp + fn) if (tp + fn) > 0 else 0 - f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 - st.metric("F1 Score", f"{f1:.3f}") - - # Show confusion matrix counts - st.caption(f"TP: {tp} | FP: {fp} | TN: {tn} | FN: {fn}") - else: - with col2: - correct = len(merged[merged["confusion_category"] == "Correct"]) - accuracy = correct / len(merged) if len(merged) > 0 else 0 - st.metric("Accuracy", f"{accuracy:.2%}") - - with col3: - incorrect = len(merged[merged["confusion_category"] == "Incorrect"]) - st.metric("Incorrect", incorrect) + with col4: + test_cells = merged[merged["in_test_split"]] + if len(test_cells) > 0: + correct = len(test_cells[test_cells["is_correct"]]) + accuracy = correct / len(test_cells) + st.metric("Test Accuracy", f"{accuracy:.2%}") + else: + st.metric("Test Accuracy", "N/A") # Add legend with st.expander("Legend", expanded=True): - st.markdown("**Confusion Matrix Categories:**") + # Split indicators (border colors) + st.markdown("**Data Split (Border Color):**") - for category, color in color_map.items(): - count = len(merged[merged["confusion_category"] == category]) - percentage = count / len(merged) * 100 if len(merged) > 0 else 0 + test_count = len(merged[merged["in_test_split"]]) + train_count = len(merged[~merged["in_test_split"]]) - st.markdown( - f'
' - f'
' - f"{category}: {count} ({percentage:.1f}%)
", - unsafe_allow_html=True, - ) + st.markdown( + f'
' + f'
' + f"Test Split ({test_count} cells, {test_count / len(merged) * 100:.1f}%)
", + unsafe_allow_html=True, + ) + + st.markdown( + f'
' + f'
' + f"Training Split ({train_count} cells, {train_count / len(merged) * 100:.1f}%)
", + unsafe_allow_html=True, + ) + + st.markdown("---") + st.markdown("**Fill Color (Prediction Results):**") + + # Correct predictions (test split only) + test_cells = merged[merged["in_test_split"]] + correct = len(test_cells[test_cells["is_correct"]]) if len(test_cells) > 0 else 0 + incorrect = len(test_cells[~test_cells["is_correct"]]) if len(test_cells) > 0 else 0 + + st.markdown( + f'
' + f'
' + f"Correct Predictions (Test) ({correct} cells, {correct / len(test_cells) * 100 if len(test_cells) > 0 else 0:.1f}%)
", + unsafe_allow_html=True, + ) + + # Incorrect predictions by predicted class (shades of red) + st.markdown( + f"Incorrect Predictions by Predicted Class (Test) ({incorrect} cells):", unsafe_allow_html=True + ) + + for class_idx, class_label in enumerate(ordered_classes): + # Get count of incorrect predictions for this predicted class (test split only) + count = len(test_cells[(~test_cells["is_correct"]) & (test_cells["predicted_class"] == class_label)]) + if count > 0: + # Get color for this predicted class + color_value = red_cmap(class_idx / max(n_classes - 1, 1)) + rgb = [int(color_value[0] * 255), int(color_value[1] * 255), int(color_value[2] * 255)] + + percentage = count / incorrect * 100 if incorrect > 0 else 0 + + st.markdown( + f'
' + f'
' + f"Predicted as {class_label}: {count} ({percentage:.1f}%)
", + unsafe_allow_html=True, + ) + + # Note about training split + st.markdown( + f'
' + f"Note: Training split cells ({train_count}) are shown with their true labels (green fill) " + f"since predictions are only available for the test split.
", + unsafe_allow_html=True, + ) st.markdown("---") st.markdown("**Elevation (3D):**") - st.markdown("Height represents prediction confidence: Errors are elevated higher than correct predictions.") + + # Show elevation mapping for each true class + st.markdown("Height represents the true label:", unsafe_allow_html=True) + for class_idx, class_label in enumerate(ordered_classes): + elevation_value = (class_idx + 1) / n_classes + height_km = elevation_value * 500 # Since elevation_scale is 500000 + st.markdown( + f'
{class_label}: {height_km:.0f} km
',
+                unsafe_allow_html=True,
+            )
 
     st.info("💡 Rotate the map by holding Ctrl/Cmd and dragging.")
+
+
+def render_confusion_matrix_heatmap(confusion_matrix: "xr.DataArray", task: str):
+    """Render confusion matrix as an interactive heatmap.
+
+    Args:
+        confusion_matrix: xarray DataArray with dimensions (true_label, predicted_label).
+        task: Task type ('binary', 'count', or 'density'); non-binary tasks are
+            treated as generic multiclass.
+
+    """
+    import plotly.express as px
+
+    # Convert to DataFrame for plotting
+    cm_df = confusion_matrix.to_pandas()
+
+    # Get labels (convert numeric labels to semantic labels if possible)
+    true_labels = confusion_matrix.coords["true_label"].values
+    pred_labels = confusion_matrix.coords["predicted_label"].values
+
+    # For binary classification, map 0/1 to No-RTS/RTS
+    if task == "binary":
+        label_map = {0: "No-RTS", 1: "RTS"}
+        true_labels_str = [label_map.get(int(label), str(label)) for label in true_labels]
+        pred_labels_str = [label_map.get(int(label), str(label)) for label in pred_labels]
+    else:
+        # For multiclass, use numeric labels as is
+        true_labels_str = [str(label) for label in true_labels]
+        pred_labels_str = [str(label) for label in pred_labels]
+
+    # Rename DataFrame indices and columns for display
+    cm_df.index = true_labels_str
+    cm_df.columns = pred_labels_str
+
+    # Store raw counts for annotations
+    cm_counts = cm_df.copy()
+
+    # Normalize by row (true label) to get percentages
+    cm_normalized = cm_df.div(cm_df.sum(axis=1), axis=0)
+
+    # Create custom text annotations showing both percentage and count
+    text_annotations = []
+    for i, true_label in enumerate(true_labels_str):
+        row_annotations = []
+        for j, pred_label in enumerate(pred_labels_str):
+            count = int(cm_counts.iloc[i, j])
+            percentage = cm_normalized.iloc[i, j] * 100
+            row_annotations.append(f"{percentage:.1f}%
<br>({count:,})")
+        text_annotations.append(row_annotations)
+
+    # Create heatmap with normalized values
+    fig = px.imshow(
+        cm_normalized,
+        labels=dict(x="Predicted Label", y="True Label", color="Proportion"),
+        x=pred_labels_str,
+        y=true_labels_str,
+        color_continuous_scale="Blues",
+        aspect="auto",
+        zmin=0,
+        zmax=1,
+    )
+
+    # Update with custom annotations
+    fig.update_traces(
+        text=text_annotations,
+        texttemplate="%{text}",
+        textfont={"size": 12},
+    )
+
+    # Update layout for better readability
+    fig.update_layout(
+        title="Confusion Matrix (Normalized by True Label)",
+        xaxis_title="Predicted Label",
+        yaxis_title="True Label",
+        height=500,
+    )
+
+    # Update colorbar to show percentage
+    fig.update_coloraxes(
+        colorbar=dict(
+            title="Proportion",
+            tickformat=".0%",
+        )
+    )
+
+    st.plotly_chart(fig, width="stretch")
+
+    st.caption(
+        "📊 Values show **row-normalized percentages** (percentage of each true class predicted as each label). "
+        "Raw counts shown in parentheses."
+    )
+
+    # Calculate and display metrics from confusion matrix
+    col1, col2, col3 = st.columns(3)
+
+    total_samples = int(cm_df.values.sum())
+    correct_predictions = int(np.trace(cm_df.values))
+    accuracy = correct_predictions / total_samples if total_samples > 0 else 0
+
+    with col1:
+        st.metric("Total Samples", f"{total_samples:,}")
+
+    with col2:
+        st.metric("Correct Predictions", f"{correct_predictions:,}")
+
+    with col3:
+        st.metric("Accuracy", f"{accuracy:.2%}")
+
+    # Add detailed breakdown for binary classification
+    if task == "binary":
+        st.markdown("#### Binary Classification Metrics")
+
+        # Extract TP, TN, FP, FN from confusion matrix
+        # Assuming 0=No-RTS (negative), 1=RTS (positive)
+        tn = int(cm_df.iloc[0, 0])
+        fp = int(cm_df.iloc[0, 1])
+        fn = int(cm_df.iloc[1, 0])
+        tp = int(cm_df.iloc[1, 1])
+
+        col1, col2, col3, col4 = st.columns(4)
+
+        with col1:
+            st.metric("True Positives (TP)", f"{tp:,}")
+
+        with col2:
+            st.metric("True Negatives (TN)", f"{tn:,}")
+
+        with col3:
+            st.metric("False Positives (FP)", f"{fp:,}")
+
+        with col4:
+            st.metric("False Negatives (FN)", f"{fn:,}")
diff --git a/src/entropice/dashboard/plots/inference.py b/src/entropice/dashboard/plots/inference.py
index 4713cdc..57ba9f9 100644
--- a/src/entropice/dashboard/plots/inference.py
+++ b/src/entropice/dashboard/plots/inference.py
@@ -6,6 +6,7 @@ import plotly.graph_objects as go
 import pydeck as pdk
 import streamlit as st
 
+from entropice.dashboard.utils.class_ordering import get_ordered_classes, sort_class_series
 from entropice.dashboard.utils.colors import get_palette
 from entropice.dashboard.utils.geometry import fix_hex_geometry
 from entropice.dashboard.utils.loaders import TrainingResult
@@ -64,8 +65,9 @@ def render_class_distribution_histogram(predictions_gdf: gpd.GeoDataFrame, task:
     """
     st.subheader("📊 Predicted Class Distribution")
 
-    # Get class counts
-    class_counts = predictions_gdf["predicted_class"].value_counts().sort_index()
+    # Get class counts and order them properly
+    class_counts = predictions_gdf["predicted_class"].value_counts()
+    class_counts = sort_class_series(class_counts, task)
 
     # Get colors based on task
     categories = class_counts.index.tolist()
@@ -121,7 +123,7 @@ def render_spatial_distribution_stats(predictions_gdf: gpd.GeoDataFrame):
     st.subheader("🌍 Spatial Coverage")
 
     # Calculate spatial extent
-    bounds = predictions_gdf.total_bounds
+    bounds = predictions_gdf.to_crs("EPSG:4326").total_bounds
 
     col1, col2, col3, col4 = st.columns(4)
 
@@ -193,8 +195,8 @@ def render_inference_map(result: TrainingResult):
col1, col2, col3 = st.columns([2, 2, 1])
 
     with col1:
-        # Get unique classes for filtering
-        all_classes = sorted(preds_gdf["predicted_class"].unique())
+        # Get unique classes for filtering (properly ordered)
+        all_classes = get_ordered_classes(task, preds_gdf["predicted_class"].unique().tolist())
         filter_options = ["All Classes", *all_classes]
 
         selected_filter = st.selectbox(
@@ -240,11 +242,13 @@
     if grid == "hex":
         display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(fix_hex_geometry)
 
-    # Assign colors based on predicted class
-    colors_palette = get_palette(task, len(all_classes))
+    # Assign colors based on predicted class (using canonical ordering)
+    # Get all possible classes for this task to ensure consistent colors
+    canonical_classes = get_ordered_classes(task)
+    colors_palette = get_palette(task, len(canonical_classes))
 
-    # Create color mapping for all classes
-    color_map = {cls: colors_palette[i] for i, cls in enumerate(all_classes)}
+    # Create color mapping for canonical classes
+    color_map = {cls: colors_palette[i] for i, cls in enumerate(canonical_classes)}
 
     # Convert hex colors to RGB
     def hex_to_rgb(hex_color):
@@ -349,8 +353,9 @@ def render_class_comparison(predictions_gdf: gpd.GeoDataFrame, task: str):
     """
     st.subheader("🔍 Class Comparison")
 
-    # Get class distribution
+    # Get class distribution and order properly
     class_counts = predictions_gdf["predicted_class"].value_counts()
+    class_counts = sort_class_series(class_counts, task)
 
     if len(class_counts) < 2:
         st.info("Need at least 2 classes for comparison.")
@@ -362,7 +367,13 @@
 
     with col1:
-        st.markdown("**Class Proportions")
+        st.markdown("**Class Proportions**")
 
-        colors = get_palette(task, len(class_counts))
+        # Get colors matching canonical order
+        canonical_classes = get_ordered_classes(task)
+        all_colors = get_palette(task, len(canonical_classes))
+        color_map = {cls: all_colors[i] for i, cls in enumerate(canonical_classes)}
+
+        # Extract colors for available classes in proper order
+        colors = [color_map[cls] for cls in class_counts.index]
 
         fig = go.Figure(
             data=[
diff --git a/src/entropice/dashboard/plots/model_state.py b/src/entropice/dashboard/plots/model_state.py
index c0b6981..86d1465 100644
--- a/src/entropice/dashboard/plots/model_state.py
+++ b/src/entropice/dashboard/plots/model_state.py
@@ -271,7 +271,7 @@ def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
     return chart
 
 
-def plot_era5_time_heatmap(era5_array: xr.DataArray) -> alt.Chart:
+def plot_era5_time_heatmap(era5_array: xr.DataArray) -> alt.Chart | None:
     """Create a heatmap showing ERA5 feature weights by variable and year (averaging over season).
 
     This is specifically for seasonal/shoulder data to show temporal trends.
diff --git a/src/entropice/dashboard/plots/overview.py b/src/entropice/dashboard/plots/overview.py
index 0770f01..91de690 100644
--- a/src/entropice/dashboard/plots/overview.py
+++ b/src/entropice/dashboard/plots/overview.py
@@ -48,50 +48,78 @@ def create_sample_count_heatmap(
 
 def create_sample_count_bar_chart(
     sample_df: pd.DataFrame,
-    task_colors: list[str] | None = None,
+    target_color_maps: dict[str, list[str]] | None = None,
 ) -> go.Figure:
     """Create bar chart showing sample counts by grid, target, and task.
 
     Args:
         sample_df: DataFrame with columns: Grid, Target, Task, Samples (Coverage).
-        task_colors: Optional color palette for tasks. If None, uses default Plotly colors.
+ target_color_maps: Optional dictionary mapping target names ("rts", "mllabels") to color palettes. + If None, uses default Plotly colors. Returns: Plotly Figure object containing the bar chart visualization. """ - fig = px.bar( - sample_df, - x="Grid", - y="Samples (Coverage)", - color="Task", - facet_col="Target", - barmode="group", - title="Sample Counts by Grid Configuration and Target Dataset", - labels={ - "Grid": "Grid Configuration", - "Samples (Coverage)": "Number of Samples", - }, - color_discrete_sequence=task_colors, - height=500, + # Create subplots manually to have better control over colors + from plotly.subplots import make_subplots + + targets = sorted(sample_df["Target"].unique()) + tasks = sorted(sample_df["Task"].unique()) + + fig = make_subplots( + rows=1, + cols=len(targets), + subplot_titles=[f"Target: {target}" for target in targets], + shared_yaxes=True, + ) + + for col_idx, target in enumerate(targets, 1): + target_data = sample_df[sample_df["Target"] == target] + # Get color palette for this target + colors = target_color_maps.get(target, None) if target_color_maps else None + + for task_idx, task in enumerate(tasks): + task_data = target_data[target_data["Task"] == task] + color = colors[task_idx] if colors and task_idx < len(colors) else None + + # Create a unique legendgroup per task so colors are consistent + fig.add_trace( + go.Bar( + x=task_data["Grid"], + y=task_data["Samples (Coverage)"], + name=task, + marker_color=color, + legendgroup=task, # Group by task name + showlegend=(col_idx == 1), # Only show legend for first subplot + ), + row=1, + col=col_idx, + ) + + fig.update_layout( + title_text="Training Sample Counts by Grid Configuration and Target Dataset", + barmode="group", + height=500, + showlegend=True, ) - # Update facet labels to be cleaner - fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1])) fig.update_xaxes(tickangle=-45) + fig.update_yaxes(title_text="Number of Samples", row=1, col=1) return fig def create_feature_count_stacked_bar( breakdown_df: pd.DataFrame, - source_colors: list[str] | None = None, + source_color_map: dict[str, str] | None = None, ) -> go.Figure: """Create stacked bar chart showing feature counts by data source. Args: breakdown_df: DataFrame with columns: Grid, Data Source, Number of Features. - source_colors: Optional color palette for data sources. If None, uses default Plotly colors. + source_color_map: Optional dictionary mapping data source names to specific colors. + If None, uses default Plotly colors. Returns: Plotly Figure object containing the stacked bar chart visualization. 
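+
+    Example (illustrative; assumes these member names appear in the Data Source column):
+        >>> cmap = {"ERA5": "#1f77b4", "ArcticDEM": "#ff7f0e", "AlphaEarth": "#2ca02c"}
+        >>> fig = create_feature_count_stacked_bar(breakdown_df, source_color_map=cmap)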
@@ -103,12 +131,12 @@ def create_feature_count_stacked_bar( y="Number of Features", color="Data Source", barmode="stack", - title="Total Features by Data Source Across Grid Configurations", + title="Input Features by Data Source Across Grid Configurations", labels={ "Grid": "Grid Configuration", "Number of Features": "Number of Features", }, - color_discrete_sequence=source_colors, + color_discrete_map=source_color_map, text_auto=False, ) @@ -136,7 +164,7 @@ def create_inference_cells_bar( x="Grid", y="Inference Cells", color="Grid", - title="Inference Cells by Grid Configuration", + title="Spatial Coverage (Grid Cells with Complete Data)", labels={ "Grid": "Grid Configuration", "Inference Cells": "Number of Cells", @@ -170,7 +198,7 @@ def create_total_samples_bar( x="Grid", y="Total Samples", color="Grid", - title="Total Samples by Grid Configuration", + title="Training Samples (Binary Task)", labels={ "Grid": "Grid Configuration", "Total Samples": "Number of Samples", @@ -188,14 +216,15 @@ def create_total_samples_bar( def create_feature_breakdown_donut( grid_data: pd.DataFrame, grid_config: str, - source_colors: list[str] | None = None, + source_color_map: dict[str, str] | None = None, ) -> go.Figure: """Create donut chart for feature breakdown by data source for a specific grid. Args: grid_data: DataFrame with columns: Data Source, Number of Features. grid_config: Grid configuration name for the title. - source_colors: Optional color palette for data sources. If None, uses default Plotly colors. + source_color_map: Optional dictionary mapping data source names to specific colors. + If None, uses default Plotly colors. Returns: Plotly Figure object containing the donut chart visualization. @@ -207,7 +236,8 @@ def create_feature_breakdown_donut( values="Number of Features", title=grid_config, hole=0.4, - color_discrete_sequence=source_colors, + color_discrete_map=source_color_map, + color="Data Source", ) fig.update_traces(textposition="inside", textinfo="percent") @@ -218,13 +248,14 @@ def create_feature_breakdown_donut( def create_feature_distribution_pie( breakdown_df: pd.DataFrame, - source_colors: list[str] | None = None, + source_color_map: dict[str, str] | None = None, ) -> go.Figure: """Create pie chart for feature distribution by data source. Args: breakdown_df: DataFrame with columns: Data Source, Number of Features. - source_colors: Optional color palette for data sources. If None, uses default Plotly colors. + source_color_map: Optional dictionary mapping data source names to specific colors. + If None, uses default Plotly colors. Returns: Plotly Figure object containing the pie chart visualization. @@ -236,7 +267,8 @@ def create_feature_distribution_pie( values="Number of Features", title="Feature Distribution by Data Source", hole=0.4, - color_discrete_sequence=source_colors, + color_discrete_map=source_color_map, + color="Data Source", ) fig.update_traces(textposition="inside", textinfo="percent+label") diff --git a/src/entropice/dashboard/utils/class_ordering.py b/src/entropice/dashboard/utils/class_ordering.py new file mode 100644 index 0000000..79fd6f4 --- /dev/null +++ b/src/entropice/dashboard/utils/class_ordering.py @@ -0,0 +1,70 @@ +"""Utilities for ordering predicted classes consistently across visualizations. + +This module leverages the canonical class labels defined in the ML dataset module +to ensure consistent ordering across all visualizations. 
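+
+Illustrative example, using the density labels defined below:
+
+    >>> import pandas as pd
+    >>> counts = pd.Series({"Dense": 2, "Empty": 7, "Sparse": 4})
+    >>> sort_class_series(counts, "density")
+    Empty     7
+    Sparse    4
+    Dense     2
+    dtype: int64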
+""" + +import pandas as pd + +from entropice.utils.types import Task + +# Canonical orderings imported from the ML pipeline +# Binary labels are defined inline in dataset.py: {False: "No RTS", True: "RTS"} +# Count/Density labels are defined in the bin_values function +BINARY_LABELS = ["No RTS", "RTS"] +COUNT_LABELS = ["None", "Very Few", "Few", "Several", "Many", "Very Many"] +DENSITY_LABELS = ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"] + +CLASS_ORDERINGS: dict[Task | str, list[str]] = { + "binary": BINARY_LABELS, + "count": COUNT_LABELS, + "density": DENSITY_LABELS, +} + + +def get_ordered_classes(task: Task | str, available_classes: list[str] | None = None) -> list[str]: + """Get properly ordered class labels for a given task. + + This uses the same canonical ordering as defined in the ML dataset module, + ensuring consistency between training and inference visualizations. + + Args: + task: Task type ('binary', 'count', 'density'). + available_classes: Optional list of available classes to filter and order. + If None, returns all canonical classes for the task. + + Returns: + List of class labels in proper order. + + Examples: + >>> get_ordered_classes("binary") + ['No RTS', 'RTS'] + >>> get_ordered_classes("count", ["None", "Few", "Several"]) + ['None', 'Few', 'Several'] + + """ + canonical_order = CLASS_ORDERINGS[task] + + if available_classes is None: + return canonical_order + + # Filter canonical order to only include available classes, preserving order + return [cls for cls in canonical_order if cls in available_classes] + + +def sort_class_series(series: pd.Series, task: Task | str) -> pd.Series: + """Sort a pandas Series with class labels according to canonical ordering. + + Args: + series: Pandas Series with class labels as index. + task: Task type ('binary', 'count', 'density'). + + Returns: + Sorted Series with classes in canonical order. + + """ + available_classes = series.index.tolist() + ordered_classes = get_ordered_classes(task, available_classes) + + # Reindex to get proper order + return series.reindex(ordered_classes) diff --git a/src/entropice/dashboard/utils/geometry.py b/src/entropice/dashboard/utils/geometry.py index 359db95..d9d7cef 100644 --- a/src/entropice/dashboard/utils/geometry.py +++ b/src/entropice/dashboard/utils/geometry.py @@ -1,13 +1,9 @@ """Geometry utilities for dashboard visualizations.""" +import antimeridian import streamlit as st from shapely.geometry import shape -try: - import antimeridian -except ImportError: - antimeridian = None - def fix_hex_geometry(geom): """Fix hexagon geometry crossing the antimeridian. @@ -23,13 +19,9 @@ def fix_hex_geometry(geom): Fixed geometry object with antimeridian issues resolved. Note: - If the antimeridian library is not available or an error occurs, - returns the original geometry unchanged. + If an error occurs, returns the original geometry unchanged. 
""" - if antimeridian is None: - return geom - try: return shape(antimeridian.fix_shape(geom)) except ValueError as e: diff --git a/src/entropice/dashboard/utils/loaders.py b/src/entropice/dashboard/utils/loaders.py index bb3b467..93f3c79 100644 --- a/src/entropice/dashboard/utils/loaders.py +++ b/src/entropice/dashboard/utils/loaders.py @@ -35,6 +35,8 @@ class TrainingResult: path: Path settings: TrainingSettings results: pd.DataFrame + metrics: dict[str, float] + confusion_matrix: xr.DataArray created_at: float available_metrics: list[str] @@ -44,23 +46,32 @@ class TrainingResult: result_file = result_path / "search_results.parquet" preds_file = result_path / "predicted_probabilities.parquet" settings_file = result_path / "search_settings.toml" + metrics_file = result_path / "test_metrics.toml" + confusion_matrix_file = result_path / "confusion_matrix.nc" if not result_file.exists(): raise FileNotFoundError(f"Missing results file in {result_path}") if not settings_file.exists(): raise FileNotFoundError(f"Missing settings file in {result_path}") if not preds_file.exists(): raise FileNotFoundError(f"Missing predictions file in {result_path}") + if not metrics_file.exists(): + raise FileNotFoundError(f"Missing metrics file in {result_path}") + if not confusion_matrix_file.exists(): + raise FileNotFoundError(f"Missing confusion matrix file in {result_path}") created_at = result_path.stat().st_ctime settings = TrainingSettings(**(toml.load(settings_file)["settings"])) results = pd.read_parquet(result_file) - + metrics = toml.load(metrics_file)["test_metrics"] + confusion_matrix = xr.open_dataarray(confusion_matrix_file, engine="h5netcdf") available_metrics = [col.replace("mean_test_", "") for col in results.columns if col.startswith("mean_test_")] return cls( path=result_path, settings=settings, results=results, + metrics=metrics, + confusion_matrix=confusion_matrix, created_at=created_at, available_metrics=available_metrics, ) @@ -126,7 +137,7 @@ def load_all_training_results() -> list[TrainingResult]: try: training_result = TrainingResult.from_path(result_path) except FileNotFoundError as e: - st.warning(f"Skipping incomplete training result at {result_path}: {e}") + st.warning(f"Skipping incomplete training result: {e}") continue training_results.append(training_result) diff --git a/src/entropice/dashboard/views/inference_page.py b/src/entropice/dashboard/views/inference_page.py index b5f2835..0ddd6b3 100644 --- a/src/entropice/dashboard/views/inference_page.py +++ b/src/entropice/dashboard/views/inference_page.py @@ -14,7 +14,6 @@ from entropice.dashboard.plots.inference import ( from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results -@st.fragment def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult: """Render sidebar for training run selection. 
diff --git a/src/entropice/dashboard/views/overview_page.py b/src/entropice/dashboard/views/overview_page.py index 8c0d841..799a392 100644 --- a/src/entropice/dashboard/views/overview_page.py +++ b/src/entropice/dashboard/views/overview_page.py @@ -13,8 +13,6 @@ from entropice.dashboard.plots.overview import ( create_feature_distribution_pie, create_inference_cells_bar, create_sample_count_bar_chart, - create_sample_count_heatmap, - create_total_samples_bar, ) from entropice.dashboard.utils.colors import get_palette from entropice.dashboard.utils.loaders import load_all_training_results @@ -32,11 +30,9 @@ from entropice.utils.types import ( def render_sample_count_overview(): """Render overview of sample counts per task+target+grid+level combination.""" - st.subheader("📊 Sample Counts by Configuration") - st.markdown( """ - This visualization shows the number of available samples for each combination of: + This visualization shows the number of available training samples for each combination of: - **Task**: binary, count, density - **Target Dataset**: darts_rts, darts_mllabels - **Grid System**: hex, healpix @@ -47,68 +43,39 @@ def render_sample_count_overview(): # Get sample count DataFrame from cache all_stats = load_all_default_dataset_statistics() sample_df = DatasetStatistics.get_sample_count_df(all_stats) - target_datasets = ["darts_rts", "darts_mllabels"] - # Create tabs for different views - tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"]) + # Get color palettes for each target dataset + n_tasks = sample_df["Task"].nunique() + target_color_maps = { + "rts": get_palette("task_types", n_colors=n_tasks), + "mllabels": get_palette("data_sources", n_colors=n_tasks), + } - with tab1: - st.markdown("### Sample Counts Heatmap") - st.markdown("Showing counts of samples with coverage") + # Create and display bar chart + fig = create_sample_count_bar_chart(sample_df, target_color_maps=target_color_maps) + st.plotly_chart(fig, use_container_width=True) - # Create heatmap for each target dataset - for target in target_datasets: - target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")] + # Display full table with formatting + st.markdown("#### Detailed Sample Counts") + display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy() - # Pivot for heatmap: Grid x Task - pivot_df = target_df.pivot_table( - index="Grid", - columns="Task", - values="Samples (Coverage)", - aggfunc="mean", - ) + # Format numbers with commas + display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}") + # Format coverage as percentage with 2 decimal places + display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%") - # Sort index by grid type and level - sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid") - pivot_df = pivot_df.reindex(sort_order.sort_values("Grid_Level_Sort").index) - - # Get color palette for sample counts - sample_colors = get_palette(f"sample_counts_{target}", n_colors=10) - - # Create and display heatmap - fig = create_sample_count_heatmap(pivot_df, target, colorscale=sample_colors) - st.plotly_chart(fig, width="stretch") - - with tab2: - st.markdown("### Sample Counts Bar Chart") - st.markdown("Showing counts of samples with coverage") - - # Get color palette for tasks - n_tasks = sample_df["Task"].nunique() - task_colors = get_palette("task_types", n_colors=n_tasks) - - # Create and display bar chart - fig = 
create_sample_count_bar_chart(sample_df, task_colors=task_colors) - st.plotly_chart(fig, width="stretch") - - with tab3: - st.markdown("### Detailed Sample Counts") - - # Display full table with formatting - display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy() - - # Format numbers with commas - display_df["Samples (Coverage)"] = display_df["Samples (Coverage)"].apply(lambda x: f"{x:,}") - # Format coverage as percentage with 2 decimal places - display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%") - - st.dataframe(display_df, hide_index=True, width="stretch") + st.dataframe(display_df, hide_index=True, use_container_width=True) def render_feature_count_comparison(): """Render static comparison of feature counts across all grid configurations.""" - st.markdown("### Feature Count Comparison Across Grid Configurations") - st.markdown("Comparing feature counts for all grid configurations with all data sources enabled") + st.markdown( + """ + Comparing dataset characteristics for all grid configurations with all data sources enabled. + - **Features**: Total number of input features from all data sources + - **Spatial Coverage**: Number of grid cells with complete data coverage + """ + ) # Get data from cache all_stats = load_all_default_dataset_statistics() @@ -116,87 +83,44 @@ def render_feature_count_comparison(): breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats) breakdown_df = breakdown_df.sort_values("Grid_Level_Sort") - # Create tabs for different comparison views - comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"]) + # Get all unique data sources and create color map + unique_sources = sorted(breakdown_df["Data Source"].unique()) + n_sources = len(unique_sources) + source_color_list = get_palette("data_sources", n_colors=n_sources) + source_color_map = dict(zip(unique_sources, source_color_list)) - with comp_tab1: - st.markdown("#### Total Features by Grid Configuration") + # Create and display stacked bar chart + fig = create_feature_count_stacked_bar(breakdown_df, source_color_map=source_color_map) + st.plotly_chart(fig, use_container_width=True) - # Get color palette for data sources - unique_sources = breakdown_df["Data Source"].unique() - n_sources = len(unique_sources) - source_colors = get_palette("data_sources", n_colors=n_sources) + # Add spatial coverage metric + n_grids = len(comparison_df) + grid_colors = get_palette("grid_configs", n_colors=n_grids) - # Create and display stacked bar chart - fig = create_feature_count_stacked_bar(breakdown_df, source_colors=source_colors) - st.plotly_chart(fig, width="stretch") + fig_cells = create_inference_cells_bar(comparison_df, grid_colors=grid_colors) + st.plotly_chart(fig_cells, use_container_width=True) - # Add secondary metrics - col1, col2 = st.columns(2) - # Get color palette for grid configs - n_grids = len(comparison_df) - grid_colors = get_palette("grid_configs", n_colors=n_grids) + # Display full comparison table with formatting + st.markdown("#### Detailed Comparison Table") + display_df = comparison_df[ + [ + "Grid", + "Total Features", + "Data Sources", + "Inference Cells", + ] + ].copy() - with col1: - fig_cells = create_inference_cells_bar(comparison_df, grid_colors=grid_colors) - st.plotly_chart(fig_cells, width="stretch") + # Format numbers with commas + for col in ["Total Features", "Inference Cells"]: + display_df[col] = display_df[col].apply(lambda x: f"{x:,}") - with col2: - fig_samples = 
create_total_samples_bar(comparison_df, grid_colors=grid_colors) - st.plotly_chart(fig_samples, width="stretch") - - with comp_tab2: - st.markdown("#### Feature Breakdown by Data Source") - st.markdown("Showing percentage contribution of each data source across all grid configurations") - - # Get color palette for data sources - unique_sources = breakdown_df["Data Source"].unique() - n_sources = len(unique_sources) - source_colors = get_palette("data_sources", n_colors=n_sources) - - # Create donut charts for each grid configuration - # Organize in a grid layout - num_grids = len(comparison_df) - cols_per_row = 3 - num_rows = (num_grids + cols_per_row - 1) // cols_per_row - - for row_idx in range(num_rows): - cols = st.columns(cols_per_row) - for col_idx in range(cols_per_row): - grid_idx = row_idx * cols_per_row + col_idx - if grid_idx < num_grids: - grid_config = comparison_df.iloc[grid_idx]["Grid"] - grid_data = breakdown_df[breakdown_df["Grid"] == grid_config] - - with cols[col_idx]: - fig = create_feature_breakdown_donut(grid_data, grid_config, source_colors=source_colors) - st.plotly_chart(fig, width="stretch") - - with comp_tab3: - st.markdown("#### Detailed Feature Count Comparison") - - # Display full comparison table with formatting - display_df = comparison_df[ - [ - "Grid", - "Total Features", - "Data Sources", - "Inference Cells", - "Total Samples", - ] - ].copy() - - # Format numbers with commas - for col in ["Total Features", "Inference Cells", "Total Samples"]: - display_df[col] = display_df[col].apply(lambda x: f"{x:,}") - - st.dataframe(display_df, hide_index=True, width="stretch") + st.dataframe(display_df, hide_index=True, use_container_width=True) @st.fragment def render_feature_count_explorer(): """Render interactive detailed configuration explorer using fragments.""" - st.markdown("### Detailed Configuration Explorer") st.markdown("Select specific grid configuration and data sources for detailed statistics") # Grid selection @@ -250,8 +174,6 @@ def render_feature_count_explorer(): # Show results if at least one member is selected if selected_members: - st.markdown("---") - # Get statistics from cache (already loaded) grid_stats = all_stats[selected_grid_config.id] @@ -301,12 +223,14 @@ def render_feature_count_explorer(): breakdown_df = pd.DataFrame(breakdown_data) - # Get color palette for data sources - n_sources = len(breakdown_df) - source_colors = get_palette("data_sources", n_colors=n_sources) + # Get all unique data sources and create color map + unique_sources = sorted(breakdown_df["Data Source"].unique()) + n_sources = len(unique_sources) + source_color_list = get_palette("data_sources", n_colors=n_sources) + source_color_map = dict(zip(unique_sources, source_color_list)) # Create and display pie chart - fig = create_feature_distribution_pie(breakdown_df, source_colors=source_colors) + fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map) st.plotly_chart(fig, width="stretch") # Show detailed table @@ -363,38 +287,82 @@ def render_feature_count_explorer(): st.info("👆 Select at least one data source to see feature statistics") -def render_feature_count_section(): - """Render the feature count section with comparison and explorer.""" - st.subheader("🔢 Feature Counts by Dataset Configuration") - - st.markdown( - """ - This visualization shows the total number of features that would be generated - for different combinations of data sources and grid configurations. 
-        """
-    )
-
-    # Static comparison across all grids
-    render_feature_count_comparison()
-
-    st.divider()
-
-    # Interactive explorer for detailed analysis
-    render_feature_count_explorer()
-
-
 def render_dataset_analysis():
     """Render the dataset analysis section with sample and feature counts."""
     st.header("📈 Dataset Analysis")
 
-    # Create tabs for the two different analyses
-    analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
+    # Create tabs for different analysis views
+    analysis_tabs = st.tabs(
+        [
+            "📊 Training Samples",
+            "📈 Dataset Characteristics",
+            "🔍 Feature Breakdown",
+            "⚙️ Configuration Explorer",
+        ]
+    )
 
     with analysis_tabs[0]:
+        st.subheader("Training Samples by Configuration")
         render_sample_count_overview()
 
     with analysis_tabs[1]:
-        render_feature_count_section()
+        st.subheader("Dataset Characteristics Across Grid Configurations")
+        render_feature_count_comparison()
+
+    with analysis_tabs[2]:
+        st.subheader("Feature Breakdown by Data Source")
+        # Get data from cache
+        all_stats = load_all_default_dataset_statistics()
+        comparison_df = DatasetStatistics.get_feature_count_df(all_stats)
+        breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
+        breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
+
+        # Get all unique data sources and create color map
+        unique_sources = sorted(breakdown_df["Data Source"].unique())
+        n_sources = len(unique_sources)
+        source_color_list = get_palette("data_sources", n_colors=n_sources)
+        source_color_map = dict(zip(unique_sources, source_color_list))
+
+        st.markdown("Showing percentage contribution of each data source across all grid configurations")
+
+        # Donut charts per resolution: hex grids in the left column, healpix grids in the right
+        for res in ["sparse", "low", "medium"]:
+            cols = st.columns(2)
+            with cols[0]:
+                grid_configs_res = [gc for gc in grid_configs if gc.res == res and gc.grid == "hex"]
+                for gc in grid_configs_res:
+                    grid_display = gc.display_name
+                    grid_data = breakdown_df[breakdown_df["Grid"] == grid_display]
+                    fig = create_feature_breakdown_donut(grid_data, grid_display, source_color_map=source_color_map)
+                    st.plotly_chart(fig, width="stretch", key=f"donut_{grid_display}")
+            with cols[1]:
+                grid_configs_res = [gc for gc in grid_configs if gc.res == res and gc.grid == "healpix"]
+                for gc in grid_configs_res:
+                    grid_display = gc.display_name
+                    grid_data = breakdown_df[breakdown_df["Grid"] == grid_display]
+                    fig = create_feature_breakdown_donut(grid_data, grid_display, source_color_map=source_color_map)
+                    st.plotly_chart(fig, width="stretch", key=f"donut_{grid_display}")
+
+        # Create donut charts for each grid configuration
+        # num_grids = len(comparison_df)
+        # cols_per_row = 3
+        # num_rows = (num_grids + cols_per_row - 1) // cols_per_row
+
+        # for row_idx in range(num_rows):
+        #     cols = st.columns(cols_per_row)
+        #     for col_idx in range(cols_per_row):
+        #         grid_idx = row_idx * cols_per_row + col_idx
+        #         if grid_idx < num_grids:
+        #             grid_config = comparison_df.iloc[grid_idx]["Grid"]
+        #             grid_data = breakdown_df[breakdown_df["Grid"] == grid_config]
+
+        #             with cols[col_idx]:
+        #                 fig = create_feature_breakdown_donut(grid_data, grid_config, source_color_map=source_color_map)
+        #                 st.plotly_chart(fig, use_container_width=True)
+
+    with analysis_tabs[3]:
+        st.subheader("Interactive Configuration Explorer")
+        render_feature_count_explorer()
 
 
 def render_training_results_summary(training_results):
diff --git a/src/entropice/dashboard/views/training_analysis_page.py b/src/entropice/dashboard/views/training_analysis_page.py
index 8609331..e73dbd8 100644
--- a/src/entropice/dashboard/views/training_analysis_page.py
+++ b/src/entropice/dashboard/views/training_analysis_page.py
@@ -1,10 +1,13 @@
 """Training Results Analysis page: Analysis of training results and model performance."""
 
+from typing import cast
+
 import streamlit as st
 from stopuhr import stopwatch
 
 from entropice.dashboard.plots.hyperparameter_analysis import (
     render_binned_parameter_space,
+    render_confusion_matrix_heatmap,
     render_confusion_matrix_map,
     render_espa_binned_parameter_space,
     render_multi_metric_comparison,
@@ -14,12 +17,12 @@ from entropice.dashboard.plots.hyperparameter_analysis import (
     render_top_configurations,
 )
 from entropice.dashboard.utils.formatters import format_metric_name
-from entropice.dashboard.utils.loaders import load_all_training_results
+from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
 from entropice.dashboard.utils.stats import CVResultsStatistics
+from entropice.utils.types import GridConfig
 
 
-@st.fragment
-def render_analysis_settings_sidebar(training_results):
+def render_analysis_settings_sidebar(training_results: list[TrainingResult]) -> tuple[TrainingResult, str, str, int]:
     """Render sidebar for training run and analysis settings selection.
 
     Args:
@@ -42,7 +45,7 @@
         key="training_run_select",
     )
 
-    selected_result = training_options[selected_name]
+    selected_result = cast(TrainingResult, training_options[selected_name])
 
     st.divider()
 
@@ -52,22 +55,7 @@
 
     available_metrics = selected_result.available_metrics
 
-    # Try to get refit metric from settings
-    refit_metric = selected_result.settings.refit_metric if hasattr(selected_result.settings, "refit_metric") else None
-
-    if not refit_metric or refit_metric not in available_metrics:
-        # Infer from task or use first available metric
-        task = selected_result.settings.task
-        if task == "binary" and "f1" in available_metrics:
-            refit_metric = "f1"
-        elif "f1_weighted" in available_metrics:
-            refit_metric = "f1_weighted"
-        elif "accuracy" in available_metrics:
-            refit_metric = "accuracy"
-        elif available_metrics:
-            refit_metric = available_metrics[0]
-        else:
-            st.error("No metrics found in results.")
-            return None, None, None, None
+    # Derive the refit metric from the task type
+    refit_metric = "f1" if selected_result.settings.task == "binary" else "f1_weighted"
 
     if refit_metric in available_metrics:
         default_metric_idx = available_metrics.index(refit_metric)
@@ -97,7 +85,7 @@
     return selected_result, selected_metric, refit_metric, top_n
 
 
-def render_run_information(selected_result, refit_metric):
+def render_run_information(selected_result: TrainingResult, refit_metric):
     """Render training run configuration overview.
 
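+    Displays task, target dataset, grid configuration, model type, and the
+    number of search trials for the selected run.
+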
Args: @@ -107,23 +95,86 @@ def render_run_information(selected_result, refit_metric): """ st.header("📋 Run Information") - col1, col2, col3, col4, col5, col6 = st.columns(6) + grid_config = GridConfig.from_grid_level(f"{selected_result.settings.grid}{selected_result.settings.level}") # ty:ignore[invalid-argument-type] + + col1, col2, col3, col4, col5 = st.columns(5) with col1: st.metric("Task", selected_result.settings.task.capitalize()) with col2: - st.metric("Grid", selected_result.settings.grid.capitalize()) + st.metric("Target", selected_result.settings.target.capitalize()) with col3: - st.metric("Level", selected_result.settings.level) + st.metric("Grid", grid_config.display_name) with col4: st.metric("Model", selected_result.settings.model.upper()) with col5: st.metric("Trials", len(selected_result.results)) - with col6: - st.metric("CV Splits", selected_result.settings.cv_splits) st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}") +def render_test_metrics_section(selected_result: TrainingResult): + """Render test metrics overview showing final model performance. + + Args: + selected_result: The selected TrainingResult object. + + """ + st.header("🎯 Test Set Performance") + st.caption("Performance metrics on the held-out test set (best model from hyperparameter search)") + + test_metrics = selected_result.metrics + + if not test_metrics: + st.warning("No test metrics available for this training run.") + return + + # Display metrics in columns based on task type + task = selected_result.settings.task + + if task == "binary": + # Binary classification metrics + col1, col2, col3, col4, col5 = st.columns(5) + + with col1: + st.metric("Accuracy", f"{test_metrics.get('accuracy', 0):.4f}") + with col2: + st.metric("F1 Score", f"{test_metrics.get('f1', 0):.4f}") + with col3: + st.metric("Precision", f"{test_metrics.get('precision', 0):.4f}") + with col4: + st.metric("Recall", f"{test_metrics.get('recall', 0):.4f}") + with col5: + st.metric("Jaccard", f"{test_metrics.get('jaccard', 0):.4f}") + else: + # Multiclass metrics + col1, col2, col3 = st.columns(3) + + with col1: + st.metric("Accuracy", f"{test_metrics.get('accuracy', 0):.4f}") + with col2: + st.metric("F1 (Macro)", f"{test_metrics.get('f1_macro', 0):.4f}") + with col3: + st.metric("F1 (Weighted)", f"{test_metrics.get('f1_weighted', 0):.4f}") + + col4, col5, col6 = st.columns(3) + + with col4: + st.metric("Precision (Macro)", f"{test_metrics.get('precision_macro', 0):.4f}") + with col5: + st.metric("Precision (Weighted)", f"{test_metrics.get('precision_weighted', 0):.4f}") + with col6: + st.metric("Recall (Macro)", f"{test_metrics.get('recall_macro', 0):.4f}") + + col7, col8, col9 = st.columns(3) + + with col7: + st.metric("Jaccard (Micro)", f"{test_metrics.get('jaccard_micro', 0):.4f}") + with col8: + st.metric("Jaccard (Macro)", f"{test_metrics.get('jaccard_macro', 0):.4f}") + with col9: + st.metric("Jaccard (Weighted)", f"{test_metrics.get('jaccard_weighted', 0):.4f}") + + def render_cv_statistics_section(selected_result, selected_metric): """Render cross-validation statistics for selected metric. 
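+
+    Also compares the best cross-validation score against the held-out test
+    score when the selected metric is present in the stored test metrics.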
@@ -133,6 +184,7 @@ def render_cv_statistics_section(selected_result, selected_metric): """ st.header("📈 Cross-Validation Statistics") + st.caption("Performance during hyperparameter search (averaged across CV folds)") from entropice.dashboard.utils.stats import CVMetricStatistics @@ -158,6 +210,45 @@ def render_cv_statistics_section(selected_result, selected_metric): if cv_stats.mean_cv_std is not None: st.info(f"**Mean CV Std:** {cv_stats.mean_cv_std:.4f} - Average standard deviation across CV folds") + # Compare with test metric if available + if selected_metric in selected_result.metrics: + test_score = selected_result.metrics[selected_metric] + st.divider() + st.subheader("CV vs Test Performance") + + col1, col2, col3 = st.columns(3) + with col1: + st.metric("Best CV Score", f"{cv_stats.best_score:.4f}") + with col2: + st.metric("Test Score", f"{test_score:.4f}") + with col3: + delta = test_score - cv_stats.best_score + delta_pct = (delta / cv_stats.best_score * 100) if cv_stats.best_score != 0 else 0 + st.metric("Difference", f"{delta:+.4f}", delta=f"{delta_pct:+.2f}%") + + if abs(delta) > cv_stats.std_score: + st.warning( + "⚠️ Test performance differs significantly from CV performance. " + "This may indicate overfitting or data distribution mismatch." + ) + + +def render_confusion_matrix_section(selected_result: TrainingResult): + """Render confusion matrix visualization and analysis. + + Args: + selected_result: The selected TrainingResult object. + + """ + st.header("🎲 Confusion Matrix") + st.caption("Detailed breakdown of predictions on the test set") + + if selected_result.confusion_matrix is None: + st.warning("No confusion matrix available for this training run.") + return + + render_confusion_matrix_heatmap(selected_result.confusion_matrix, selected_result.settings.task) + def render_parameter_space_section(selected_result, selected_metric): """Render parameter space analysis section. @@ -292,8 +383,19 @@ def render_training_analysis_page(): st.divider() + # Test Metrics Section + render_test_metrics_section(selected_result) + + st.divider() + + # Confusion Matrix Section + render_confusion_matrix_section(selected_result) + + st.divider() + # Performance Summary Section - st.header("📊 Performance Overview") + st.header("📊 CV Performance Overview") + st.caption("Summary of hyperparameter search results across all configurations") render_performance_summary(results, refit_metric) st.divider() diff --git a/src/entropice/dashboard/views/training_data_page.py b/src/entropice/dashboard/views/training_data_page.py index 5ee37de..a7d94d7 100644 --- a/src/entropice/dashboard/views/training_data_page.py +++ b/src/entropice/dashboard/views/training_data_page.py @@ -320,6 +320,7 @@ def render_arcticdem_view(ensemble: DatasetEnsemble, arcticdem_ds, targets): render_arcticdem_map(arcticdem_ds, targets, ensemble.grid) +@st.fragment def render_era5_view(ensemble: DatasetEnsemble, era5_data: dict[L2SourceDataset, tuple], targets): """Render ERA5 climate data analysis. 
diff --git a/src/entropice/ml/dataset.py b/src/entropice/ml/dataset.py index 8959f0a..28ce0ba 100644 --- a/src/entropice/ml/dataset.py +++ b/src/entropice/ml/dataset.py @@ -17,6 +17,7 @@ import json from collections.abc import Generator from dataclasses import asdict, dataclass, field from functools import cached_property +from itertools import product from typing import Literal, TypedDict import cupy as cp @@ -30,6 +31,7 @@ import xarray as xr from rich import pretty, traceback from sklearn import set_config from sklearn.model_selection import train_test_split +from stopuhr import stopwatch import entropice.utils.paths from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task @@ -295,6 +297,30 @@ class DatasetEnsemble: temporal: Literal["yearly", "seasonal", "shoulder"], ) -> pd.DataFrame: era5 = self._read_member("ERA5-" + temporal, targets) + + if len(era5["cell_ids"]) == 0: + # No data for these cells - create empty DataFrame with expected columns + # Use the Dataset metadata to determine column structure + variables = list(era5.data_vars) + times = era5.coords["time"].to_numpy() + time_df = pd.DataFrame({"time": times}) + time_df.index = pd.DatetimeIndex(times) + tempus = _get_era5_tempus(time_df, temporal) + unique_tempus = tempus.unique() + + if "aggregations" in era5.dims: + aggs_list = era5.coords["aggregations"].to_numpy() + expected_cols = [ + f"era5_{var}_{t}_{agg}" for var, t, agg in product(variables, unique_tempus, aggs_list) + ] + else: + expected_cols = [f"era5_{var}_{t}" for var, t in product(variables, unique_tempus)] + + return pd.DataFrame( + index=targets["cell_id"].values, + columns=expected_cols, + dtype=float, + ) era5_df = era5.to_dataframe() era5_df["t"] = _get_era5_tempus(era5_df, temporal) if "aggregations" not in era5.dims: @@ -303,20 +329,50 @@ class DatasetEnsemble: else: era5_df = era5_df.pivot_table(index="cell_ids", columns=["t", "aggregations"]) era5_df.columns = [f"era5_{var}_{t}_{agg}" for var, t, agg in era5_df.columns] + # Ensure all target cell_ids are present, fill missing with NaN + era5_df = era5_df.reindex(targets["cell_id"].values, fill_value=np.nan) return era5_df def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame: embeddings = self._read_member("AlphaEarth", targets)["embeddings"] + + if len(embeddings["cell_ids"]) == 0: + # No data for these cells - create empty DataFrame with expected columns + # Use the Dataset metadata to determine column structure + years = embeddings.coords["year"].to_numpy() + aggs = embeddings.coords["agg"].to_numpy() + bands = embeddings.coords["band"].to_numpy() + expected_cols = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in product(years, aggs, bands)] + return pd.DataFrame( + index=targets["cell_id"].values, + columns=expected_cols, + dtype=float, + ) embeddings_df = embeddings.to_dataframe(name="value") embeddings_df = embeddings_df.pivot_table(index="cell_ids", columns=["year", "agg", "band"], values="value") embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns] + # Ensure all target cell_ids are present, fill missing with NaN + embeddings_df = embeddings_df.reindex(targets["cell_id"].values, fill_value=np.nan) return embeddings_df def _prep_arcticdem(self, targets: gpd.GeoDataFrame) -> pd.DataFrame: arcticdem = self._read_member("ArcticDEM", targets) + if len(arcticdem["cell_ids"]) == 0: + # No data for these cells - create empty DataFrame with expected columns + # Use the Dataset metadata to determine column 
+            variables = list(arcticdem.data_vars)
+            aggs = arcticdem.coords["aggregations"].to_numpy()
+            expected_cols = [f"arcticdem_{var}_{agg}" for var, agg in product(variables, aggs)]
+            return pd.DataFrame(
+                index=targets["cell_id"].values,
+                columns=expected_cols,
+                dtype=float,
+            )
         arcticdem_df = arcticdem.to_dataframe().pivot_table(index="cell_ids", columns="aggregations")
         arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns]
+        # Ensure all target cell_ids are present, fill missing with NaN
+        arcticdem_df = arcticdem_df.reindex(targets["cell_id"].values, fill_value=np.nan)
         return arcticdem_df
 
     def get_stats(self) -> DatasetStats:
@@ -374,33 +430,39 @@ class DatasetEnsemble:
         # n: no cache, o: overwrite cache, r: read cache if exists
         cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col)
         if cache_mode == "r" and cache_file.exists():
-            dataset = gpd.read_parquet(cache_file)
+            with stopwatch("Loading dataset from cache"):
+                dataset = gpd.read_parquet(cache_file)
             print(
                 f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
                 f" and {len(dataset.columns)} features."
             )
             return dataset
 
-        targets = self._read_target()
-        if filter_target_col is not None:
-            targets = targets.loc[targets[filter_target_col]]
-
-        member_dfs = []
-        for member in self.members:
-            if member.startswith("ERA5"):
-                era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1]  # ty:ignore[invalid-assignment]
-                member_dfs.append(self._prep_era5(targets, era5_agg))
-            elif member == "AlphaEarth":
-                member_dfs.append(self._prep_embeddings(targets))
-            elif member == "ArcticDEM":
-                member_dfs.append(self._prep_arcticdem(targets))
-            else:
-                raise NotImplementedError(f"Member {member} not implemented.")
-
-        dataset = targets.set_index("cell_id").join(member_dfs)
+        with stopwatch("Reading target"):
+            targets = self._read_target()
+            if filter_target_col is not None:
+                targets = targets.loc[targets[filter_target_col]]
+            print(f"Read and filtered target dataset. ({len(targets)} samples)")
+        with stopwatch("Preparing member datasets"):
+            member_dfs = []
+            for member in self.members:
+                if member.startswith("ERA5"):
+                    era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1]  # ty:ignore[invalid-assignment]
+                    member_dfs.append(self._prep_era5(targets, era5_agg))
+                elif member == "AlphaEarth":
+                    member_dfs.append(self._prep_embeddings(targets))
+                elif member == "ArcticDEM":
+                    member_dfs.append(self._prep_arcticdem(targets))
+                else:
+                    raise NotImplementedError(f"Member {member} not implemented.")
+            print("Prepared all member datasets. Joining...")
Joining...") + with stopwatch("Joining datasets"): + dataset = targets.set_index("cell_id").join(member_dfs) print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.") + print("Joining complete.") if cache_mode in ["o", "r"]: - dataset.to_parquet(cache_file) + with stopwatch("Saving dataset to cache"): + dataset.to_parquet(cache_file) print(f"Saved dataset to cache at {cache_file}.") return dataset diff --git a/src/entropice/ml/inference.py b/src/entropice/ml/inference.py index d285d65..99708b1 100644 --- a/src/entropice/ml/inference.py +++ b/src/entropice/ml/inference.py @@ -46,6 +46,11 @@ def predict_proba( else: cols_to_drop += [col for col in batch.columns if col.startswith("darts_")] X_batch = batch.drop(columns=cols_to_drop).dropna() + + # Skip empty batches (all rows had NaN values) + if len(X_batch) == 0: + continue + cell_ids = X_batch.index.to_numpy() cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy() X_batch = X_batch.to_numpy(dtype="float64") diff --git a/src/entropice/ml/training.py b/src/entropice/ml/training.py index ae8c563..8a10d44 100644 --- a/src/entropice/ml/training.py +++ b/src/entropice/ml/training.py @@ -2,12 +2,14 @@ import pickle from dataclasses import asdict, dataclass +from functools import partial import cupy as cp import cyclopts import pandas as pd import toml import xarray as xr +from array_api_compat import get_namespace from cuml.ensemble import RandomForestClassifier from cuml.neighbors import KNeighborsClassifier from entropy import ESPAClassifier @@ -15,6 +17,14 @@ from rich import pretty, traceback from scipy.stats import loguniform, randint from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen from sklearn import set_config +from sklearn.metrics import ( + accuracy_score, + confusion_matrix, + f1_score, + jaccard_score, + precision_score, + recall_score, +) from sklearn.model_selection import KFold, RandomizedSearchCV from stopuhr import stopwatch from xgboost.sklearn import XGBClassifier @@ -49,6 +59,26 @@ _metrics = { } +# Compute other metrics - using predictions directly instead of re-predicting for each metric +# Use functools.partial for cleaner metric definitions with non-default parameters +_metric_functions = { + "accuracy": accuracy_score, + "recall": recall_score, + "precision": precision_score, + "f1": f1_score, + "jaccard": jaccard_score, + "recall_macro": partial(recall_score, average="macro"), + "recall_weighted": partial(recall_score, average="weighted"), + "precision_macro": partial(precision_score, average="macro"), + "precision_weighted": partial(precision_score, average="weighted"), + "f1_macro": partial(f1_score, average="macro"), + "f1_weighted": partial(f1_score, average="weighted"), + "jaccard_micro": partial(jaccard_score, average="micro"), + "jaccard_macro": partial(jaccard_score, average="macro"), + "jaccard_weighted": partial(jaccard_score, average="weighted"), +} + + @cyclopts.Parameter("*") @dataclass(frozen=True, kw_only=True) class CVSettings: @@ -200,6 +230,26 @@ def random_cv( print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}") print(f"Accuracy on test set: {test_accuracy:.3f}") + # Compute predictions on the test set + y_pred = best_estimator.predict(training_data.X.test) + labels = list(range(len(training_data.y.labels))) + xp = get_namespace(y_test) + y_test = xp.as_array(y_test) + y_pred = xp.as_array(y_pred) + labels = xp.as_array(labels) + + test_metrics = {metric: 
+
+    # Compute the confusion matrix on the test set
+    cm = confusion_matrix(y_test, y_pred, labels=labels)
+    label_names = list(training_data.y.labels)
+    cm = xr.DataArray(
+        cp.asnumpy(cm),
+        dims=["true_label", "predicted_label"],
+        coords={"true_label": label_names, "predicted_label": label_names},
+        name="confusion_matrix",
+    )
+
     results_dir = get_cv_results_dir(
         "random_search",
         grid=dataset_ensemble.grid,
@@ -239,6 +289,17 @@ def random_cv(
     print(f"Storing CV results to {results_file}")
     results.to_parquet(results_file)
 
+    # Store the test metrics
+    test_metrics_file = results_dir / "test_metrics.toml"
+    print(f"Storing test metrics to {test_metrics_file}")
+    with open(test_metrics_file, "w") as f:
+        toml.dump({"test_metrics": test_metrics}, f)
+
+    # Store the confusion matrix
+    cm_file = results_dir / "confusion_matrix.nc"
+    print(f"Storing confusion matrix to {cm_file}")
+    cm.to_netcdf(cm_file, engine="h5netcdf")
+
     # Get the inner state of the best estimator
     features = training_data.X.data.columns.tolist()
 
diff --git a/src/entropice/utils/types.py b/src/entropice/utils/types.py
index 1db121c..7c47d84 100644
--- a/src/entropice/utils/types.py
+++ b/src/entropice/utils/types.py
@@ -30,6 +30,7 @@ class GridConfig:
     level: int
     id: GridLevel
     display_name: str
+    res: Literal["sparse", "low", "medium"]
     sort_key: str
 
     @classmethod
@@ -46,11 +47,24 @@ class GridConfig:
 
         display_name = f"{grid.capitalize()}-{level}"
 
+        resmap: dict[str, Literal["sparse", "low", "medium"]] = {
+            "hex3": "sparse",
+            "hex4": "sparse",
+            "hex5": "low",
+            "hex6": "medium",
+            "healpix6": "sparse",
+            "healpix7": "sparse",
+            "healpix8": "low",
+            "healpix9": "low",
+            "healpix10": "medium",
+        }
+
         return cls(
             grid=grid,
             level=level,
             id=grid_level,
             display_name=display_name,
+            res=resmap[grid_level],
             sort_key=f"{grid}_{level:02d}",
         )
 
diff --git a/tests/debug_arcticdem_batch.py b/tests/debug_arcticdem_batch.py
new file mode 100644
index 0000000..6785c7d
--- /dev/null
+++ b/tests/debug_arcticdem_batch.py
@@ -0,0 +1,33 @@
+"""Debug script to check what _prep_arcticdem returns for a batch."""
+
+from entropice.ml.dataset import DatasetEnsemble
+
+ensemble = DatasetEnsemble(
+    grid="healpix",
+    level=10,
+    target="darts_mllabels",
+    members=["ArcticDEM"],
+    add_lonlat=True,
+    filter_target=False,
+)
+
+# Get targets
+targets = ensemble._read_target()
+print(f"Total targets: {len(targets)}")
+
+# Get first batch of targets
+batch_targets = targets.iloc[:100]
+print(f"\nBatch targets: {len(batch_targets)}")
+print(f"Cell IDs in batch: {batch_targets['cell_id'].values[:5]}")
+
+# Try to prep ArcticDEM for this batch
+print("\n" + "=" * 80)
+print("Calling _prep_arcticdem...")
+print("=" * 80)
+arcticdem_df = ensemble._prep_arcticdem(batch_targets)
+print(f"\nArcticDEM DataFrame shape: {arcticdem_df.shape}")
+print(f"ArcticDEM DataFrame index: {arcticdem_df.index[:5].tolist() if len(arcticdem_df) > 0 else 'EMPTY'}")
+print(
+    f"ArcticDEM DataFrame columns ({len(arcticdem_df.columns)}): {arcticdem_df.columns[:10].tolist() if len(arcticdem_df.columns) > 0 else 'NO COLUMNS'}"
+)
+print(f"Number of non-NaN rows: {arcticdem_df.notna().any(axis=1).sum()}")
diff --git a/tests/debug_feature_mismatch.py b/tests/debug_feature_mismatch.py
new file mode 100644
index 0000000..c994af8
--- /dev/null
+++ b/tests/debug_feature_mismatch.py
@@ -0,0 +1,72 @@
+"""Debug script to identify feature mismatch between training and inference."""
+
+from entropice.ml.dataset import DatasetEnsemble
+
+# Test with level 10 (the actual level used in production)
+ensemble = DatasetEnsemble(
+    grid="healpix",
+    level=10,
+    target="darts_mllabels",
+    members=[
+        "AlphaEarth",
+        "ArcticDEM",
+        "ERA5-yearly",
+        "ERA5-seasonal",
+        "ERA5-shoulder",
+    ],
+    add_lonlat=True,
+    filter_target=False,
+)
+
+print("=" * 80)
+print("Creating training dataset...")
+print("=" * 80)
+training_data = ensemble.create_cat_training_dataset(task="binary", device="cpu")
+training_features = set(training_data.X.data.columns)
+print(f"\nTraining dataset created with {len(training_features)} features")
+print(f"Sample features: {sorted(list(training_features))[:10]}")
+
+print("\n" + "=" * 80)
+print("Creating inference batch...")
+print("=" * 80)
+batch_generator = ensemble.create_batches(batch_size=100, cache_mode="n")
+batch = next(batch_generator, None)
+# Only the first batch is needed for the feature comparison
+if batch is None:
+    print("ERROR: No batch created!")
+else:
+    print(f"\nBatch created with {len(batch.columns)} columns")
+    print(f"Batch columns: {sorted(batch.columns)[:15]}")
+
+    # Simulate the column dropping in predict_proba (inference.py)
+    cols_to_drop = ["geometry"]
+    if ensemble.target == "darts_mllabels":
+        cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
+    else:
+        cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
+
+    print(f"\nColumns to drop: {cols_to_drop}")
+
+    inference_batch = batch.drop(columns=cols_to_drop)
+    inference_features = set(inference_batch.columns)
+
+    print(f"\nInference batch after dropping has {len(inference_features)} features")
+    print(f"Sample features: {sorted(list(inference_features))[:10]}")
+
+    print("\n" + "=" * 80)
+    print("COMPARISON")
+    print("=" * 80)
+    print(f"Training features: {len(training_features)}")
+    print(f"Inference features: {len(inference_features)}")
+
+    if training_features == inference_features:
+        print("\n✅ SUCCESS: Features match perfectly!")
+    else:
+        print("\n❌ MISMATCH DETECTED!")
+        only_in_training = training_features - inference_features
+        only_in_inference = inference_features - training_features
+
+        if only_in_training:
+            print(f"\n⚠️ Only in TRAINING ({len(only_in_training)}): {sorted(only_in_training)}")
+        if only_in_inference:
+            print(f"\n⚠️ Only in INFERENCE ({len(only_in_inference)}): {sorted(only_in_inference)}")
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
new file mode 100644
index 0000000..8420ab6
--- /dev/null
+++ b/tests/test_dataset.py
@@ -0,0 +1,310 @@
+"""Tests for dataset.py module, specifically DatasetEnsemble class."""
+
+import geopandas as gpd
+import numpy as np
+import pytest
+
+from entropice.ml.dataset import DatasetEnsemble
+
+
+@pytest.fixture
+def sample_ensemble():
+    """Create a sample DatasetEnsemble for testing with minimal data."""
+    return DatasetEnsemble(
+        grid="hex",
+        level=3,  # Use level 3 for much faster tests
+        target="darts_rts",
+        members=["AlphaEarth"],  # Use only one member for faster tests
+        add_lonlat=True,
+        filter_target="darts_has_coverage",  # Filter to reduce dataset size
+    )
+
+
+@pytest.fixture
+def sample_ensemble_mllabels():
+    """Create a sample DatasetEnsemble with mllabels target."""
+    return DatasetEnsemble(
+        grid="hex",
+        level=3,  # Use level 3 for much faster tests
+        target="darts_mllabels",
+        members=["AlphaEarth"],  # Use only one member for faster tests
+        add_lonlat=True,
+        filter_target="dartsml_has_coverage",  # Filter to reduce dataset size
+    )
+
+
+class TestDatasetEnsemble:
+    """Test suite for DatasetEnsemble class."""
class.""" + + def test_initialization(self, sample_ensemble): + """Test that DatasetEnsemble initializes correctly.""" + assert sample_ensemble.grid == "hex" + assert sample_ensemble.level == 3 + assert sample_ensemble.target == "darts_rts" + assert "AlphaEarth" in sample_ensemble.members + assert sample_ensemble.add_lonlat is True + + def test_covcol_property(self, sample_ensemble, sample_ensemble_mllabels): + """Test that covcol property returns correct column name.""" + assert sample_ensemble.covcol == "darts_has_coverage" + assert sample_ensemble_mllabels.covcol == "dartsml_has_coverage" + + def test_taskcol_property(self, sample_ensemble, sample_ensemble_mllabels): + """Test that taskcol returns correct column name for different tasks.""" + assert sample_ensemble.taskcol("binary") == "darts_has_rts" + assert sample_ensemble.taskcol("count") == "darts_rts_count" + assert sample_ensemble.taskcol("density") == "darts_rts_density" + + assert sample_ensemble_mllabels.taskcol("binary") == "dartsml_has_rts" + assert sample_ensemble_mllabels.taskcol("count") == "dartsml_rts_count" + assert sample_ensemble_mllabels.taskcol("density") == "dartsml_rts_density" + + def test_create_returns_geodataframe(self, sample_ensemble): + """Test that create() returns a GeoDataFrame.""" + dataset = sample_ensemble.create(cache_mode="n") + assert isinstance(dataset, gpd.GeoDataFrame) + + def test_create_has_expected_columns(self, sample_ensemble): + """Test that create() returns dataset with expected columns.""" + dataset = sample_ensemble.create(cache_mode="n") + + # Should have geometry column + assert "geometry" in dataset.columns + + # Should have lat/lon if add_lonlat is True + if sample_ensemble.add_lonlat: + assert "lon" in dataset.columns + assert "lat" in dataset.columns + + # Should have target columns (darts_*) + assert any(col.startswith("darts_") for col in dataset.columns) + + # Should have member data columns + assert len(dataset.columns) > 3 # More than just geometry, lat, lon + + def test_create_batches_consistency(self, sample_ensemble): + """Test that create_batches produces batches with consistent columns.""" + batch_size = 100 + batches = list(sample_ensemble.create_batches(batch_size=batch_size, cache_mode="n")) + + if len(batches) == 0: + pytest.skip("No batches created (dataset might be empty)") + + # All batches should have the same columns + first_batch_cols = set(batches[0].columns) + for i, batch in enumerate(batches[1:], start=1): + assert set(batch.columns) == first_batch_cols, ( + f"Batch {i} has different columns than batch 0. 
" + f"Difference: {set(batch.columns).symmetric_difference(first_batch_cols)}" + ) + + def test_create_vs_create_batches_columns(self, sample_ensemble): + """Test that create() and create_batches() return datasets with the same columns.""" + full_dataset = sample_ensemble.create(cache_mode="n") + batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n")) + + if len(batches) == 0: + pytest.skip("No batches created (dataset might be empty)") + + # Columns should be identical + full_cols = set(full_dataset.columns) + batch_cols = set(batches[0].columns) + + assert full_cols == batch_cols, ( + f"Column mismatch between create() and create_batches().\n" + f"Only in create(): {full_cols - batch_cols}\n" + f"Only in create_batches(): {batch_cols - full_cols}" + ) + + def test_training_dataset_feature_columns(self, sample_ensemble): + """Test that create_cat_training_dataset creates proper feature columns.""" + training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu") + + # Get the columns used for model inputs + model_input_cols = set(training_data.X.data.columns) + + # These columns should NOT be in model inputs + assert "geometry" not in model_input_cols + assert sample_ensemble.covcol not in model_input_cols + assert sample_ensemble.taskcol("binary") not in model_input_cols + + # No darts_* columns should be in model inputs + for col in model_input_cols: + assert not col.startswith("darts_"), f"Column {col} should have been dropped from model inputs" + + def test_training_dataset_feature_columns_mllabels(self, sample_ensemble_mllabels): + """Test feature columns for mllabels target.""" + training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu") + + model_input_cols = set(training_data.X.data.columns) + + # These columns should NOT be in model inputs + assert "geometry" not in model_input_cols + assert sample_ensemble_mllabels.covcol not in model_input_cols + assert sample_ensemble_mllabels.taskcol("binary") not in model_input_cols + + # No dartsml_* columns should be in model inputs + for col in model_input_cols: + assert not col.startswith("dartsml_"), f"Column {col} should have been dropped from model inputs" + + def test_inference_vs_training_feature_consistency(self, sample_ensemble): + """Test that inference batches have the same features as training data after column dropping. + + This test simulates the workflow in training.py and inference.py to ensure + feature consistency between training and inference. 
+ """ + # Step 1: Create training dataset (as in training.py) + # Use only a small subset by creating just one batch + training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu") + training_feature_cols = set(training_data.X.data.columns) + + # Step 2: Create inference batch (as in inference.py) + # Get just the first batch to speed up test + batch_generator = sample_ensemble.create_batches(batch_size=100, cache_mode="n") + batch = next(batch_generator, None) + + if batch is None: + pytest.skip("No batches created (dataset might be empty)") + + # Simulate the column dropping in predict_proba + cols_to_drop = ["geometry"] + if sample_ensemble.target == "darts_mllabels": + cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")] + else: + cols_to_drop += [col for col in batch.columns if col.startswith("darts_")] + + inference_batch = batch.drop(columns=cols_to_drop) + inference_feature_cols = set(inference_batch.columns) + + # The features should match! + assert training_feature_cols == inference_feature_cols, ( + f"Feature mismatch between training and inference!\n" + f"Only in training: {training_feature_cols - inference_feature_cols}\n" + f"Only in inference: {inference_feature_cols - training_feature_cols}\n" + f"Training features ({len(training_feature_cols)}): {sorted(training_feature_cols)}\n" + f"Inference features ({len(inference_feature_cols)}): {sorted(inference_feature_cols)}" + ) + + def test_inference_vs_training_feature_consistency_mllabels(self, sample_ensemble_mllabels): + """Test feature consistency for mllabels target.""" + training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu") + training_feature_cols = set(training_data.X.data.columns) + + # Get just the first batch to speed up test + batch_generator = sample_ensemble_mllabels.create_batches(batch_size=100, cache_mode="n") + batch = next(batch_generator, None) + + if batch is None: + pytest.skip("No batches created (dataset might be empty)") + + # Simulate the column dropping in predict_proba + cols_to_drop = ["geometry"] + if sample_ensemble_mllabels.target == "darts_mllabels": + cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")] + else: + cols_to_drop += [col for col in batch.columns if col.startswith("darts_")] + + inference_batch = batch.drop(columns=cols_to_drop) + inference_feature_cols = set(inference_batch.columns) + + assert training_feature_cols == inference_feature_cols, ( + f"Feature mismatch between training and inference!\n" + f"Only in training: {training_feature_cols - inference_feature_cols}\n" + f"Only in inference: {inference_feature_cols - training_feature_cols}" + ) + + def test_all_tasks_feature_consistency(self, sample_ensemble): + """Test that all task types produce consistent features.""" + tasks = ["binary", "count", "density"] + feature_sets = {} + + for task in tasks: + training_data = sample_ensemble.create_cat_training_dataset(task=task, device="cpu") + feature_sets[task] = set(training_data.X.data.columns) + + # All tasks should have the same features + binary_features = feature_sets["binary"] + for task, features in feature_sets.items(): + assert features == binary_features, ( + f"Task '{task}' has different features than 'binary'.\n" + f"Difference: {features.symmetric_difference(binary_features)}" + ) + + def test_training_dataset_shapes(self, sample_ensemble): + """Test that training dataset has correct shapes.""" + training_data = 
+
+        n_features = len(training_data.X.data.columns)
+        n_samples_train = training_data.X.train.shape[0]
+        n_samples_test = training_data.X.test.shape[0]
+
+        # Check X shapes
+        assert training_data.X.train.shape == (n_samples_train, n_features)
+        assert training_data.X.test.shape == (n_samples_test, n_features)
+
+        # Check y shapes
+        assert training_data.y.train.shape == (n_samples_train,)
+        assert training_data.y.test.shape == (n_samples_test,)
+
+        # Check that train + test = total samples
+        assert len(training_data.dataset) == n_samples_train + n_samples_test
+
+    def test_no_nan_in_training_features(self, sample_ensemble):
+        """Test that training features don't contain NaN values."""
+        training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
+
+        # Convert to numpy for checking (handles both numpy and cupy arrays)
+        X_train = np.asarray(training_data.X.train)
+        X_test = np.asarray(training_data.X.test)
+
+        assert not np.isnan(X_train).any(), "Training features contain NaN values"
+        assert not np.isnan(X_test).any(), "Test features contain NaN values"
+
+    def test_batch_coverage(self, sample_ensemble):
+        """Test that batches cover all data without duplication."""
+        full_dataset = sample_ensemble.create(cache_mode="n")
+        batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))
+
+        if len(batches) == 0:
+            pytest.skip("No batches created (dataset might be empty)")
+
+        # Collect all cell_ids from batches
+        batch_cell_ids = set()
+        for batch in batches:
+            batch_ids = set(batch.index)
+            # Check for duplicates across batches
+            overlap = batch_cell_ids.intersection(batch_ids)
+            assert len(overlap) == 0, f"Found {len(overlap)} duplicate cell_ids across batches"
+            batch_cell_ids.update(batch_ids)
+
+        # Check that all cell_ids from full dataset are in batches
+        full_cell_ids = set(full_dataset.index)
+        assert batch_cell_ids == full_cell_ids, (
+            f"Batch coverage mismatch.\n"
+            f"Missing from batches: {full_cell_ids - batch_cell_ids}\n"
+            f"Extra in batches: {batch_cell_ids - full_cell_ids}"
+        )
+
+
+class TestDatasetEnsembleEdgeCases:
+    """Test edge cases and error handling."""
+
+    def test_invalid_task_raises_error(self, sample_ensemble):
+        """Test that invalid task raises ValueError."""
+        with pytest.raises(ValueError, match="Invalid task"):
+            sample_ensemble.create_cat_training_dataset(task="invalid", device="cpu")  # type: ignore
+
+    def test_stats_method(self, sample_ensemble):
+        """Test that get_stats returns expected structure."""
+        stats = sample_ensemble.get_stats()
+
+        assert "target" in stats
+        assert "num_target_samples" in stats
+        assert "members" in stats
+        assert "total_features" in stats
+
+        # Check that members dict contains info for each member
+        for member in sample_ensemble.members:
+            assert member in stats["members"]
+            assert "variables" in stats["members"][member]
+            assert "num_features" in stats["members"][member]
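
To exercise the new checks in isolation, the feature-parity tests and the debug script can be run directly; the commands below are illustrative (pytest's -k filter matches the feature-consistency test names defined above):

    pytest tests/test_dataset.py -k feature_consistency -v
    python tests/debug_feature_mismatch.py  # prints the exact column diff when training and inference diverge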