Enhance training analysis page with test metrics and confusion matrix

- Added a section to display test metrics for model performance on the held-out test set.
- Implemented confusion matrix visualization to analyze prediction breakdown.
- Refactored sidebar settings to streamline metric selection and improve user experience.
- Updated cross-validation statistics to compare CV performance with test metrics.
- Enhanced DatasetEnsemble methods to handle empty data scenarios gracefully.
- Introduced debug scripts to assist in identifying feature mismatches and validating dataset preparation.
- Added comprehensive tests for DatasetEnsemble to ensure feature consistency and correct behavior across various scenarios.
This commit is contained in:
Tobias Hölzer 2026-01-07 15:56:02 +01:00
parent 4fecac535c
commit c92e856c55
23 changed files with 1845 additions and 484 deletions

223
pixi.lock generated
View file

@ -473,16 +473,16 @@ environments:
- pypi: https://files.pythonhosted.org/packages/69/ce/68d6e31f0a75a5cccc03535e47434c0ca4be37fe950e93117e455cbc362c/antimeridian-0.4.5-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/5b/03/c17464bbf682ea87e7e3de2ddc63395e359a78ae9c01f55fc78759ecbd79/anywidget-0.9.21-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e0/b1/0542e0cab6f49f151a2d7a42400f84f706fc0b64e85dc1f56708b2e9fd37/array_api_compat-1.12.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/1d/05/2709750ddb088eb2fc5053ba214b4f54334d15d4cb28217e2956b5507bac/array_api_extra-0.9.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/df/5d/493b1b5528ab5072feae30821ff3a07b7a0474213d548efb1fdf135f85c1/array_api_compat-1.13.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/21/2b/bfa1cfe370dd4ed51f834f2c6ad93b7f6263b83615ab96ad91094cc98ec6/array_api_extra-0.9.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/fb/1f/2903ef412cb82ba1f2211692f7339fd7c5aeb2764f2a97f0b6a9a18bbf52/arro3_compute-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/31/4a/72dc383d1a0d14f1d453e334e3461e229762edb1bf3f75b3ab977e9386ed/arro3_core-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/1b/df/2a5a1306dc1699b51b02c1c38c55f3564a8c4f84087c23c61e7e7ae37dfa/arro3_io-0.6.5-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/c3/1c/f06ad85180e7dd9855aa5ede901bfc2be858d7bee17d4e978a14c0ecec14/astropy-7.2.0-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/1f/07/50501947849e780cb5580ebcd7af08c14d431640562e18a8ac2b055c90ec/astropy_iers_data-0.2025.12.22.0.40.30-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/57/61/2d06c08f022c9b617b79f6c55d88e596c1795a1d211e6bf584ac4b9e9506/astropy_iers_data-0.2026.1.5.0.43.43-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/ee/34/a9914e676971a13d6cc671b1ed172f9804b50a3a80a143ff196e52f4c7ee/azure_core-1.37.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/3d/9e/1c90a122ea6180e8c72eb7294adc92531b0e08eb3d2324c2ba70d37f4802/azure_storage_blob-12.27.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/d8/3a/6ef2047a072e54e1142718d433d50e9514c999a58f51abfff7902f3a72f8/azure_storage_blob-12.28.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/96/9a/663251dfb35aaddcbdbef78802ea5a9d3fad9d5fadde8774eacd9e1bfbb7/boost_histogram-1.6.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/3c/56/f47a80254ed4991cce9a2f6d8ae8aafbc8df1c3270e966b2927289e5a12f/boto3-1.41.5-py3-none-any.whl
@ -494,12 +494,11 @@ environments:
- pypi: https://files.pythonhosted.org/packages/27/27/6414b1b7e5e151300c54e28ad1cf3e3b34fe66dc3256a989b031166b1ba3/cdshealpix-0.7.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/a3/8f/c42a98f933022c7de00142526c9b6b7429fdcd0fc66c952b4ebbf0ff3b7f/cf_xarray-0.10.10-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/ba/08/52f06ff2f04d376f9cd2c211aefcf2b37f1978e43289341f362fc99f6a0e/cftime-1.6.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/3d/9a/2abecb28ae875e39c8cad711eb1186d8d14eab564705325e77e4e6ab9ae5/click_plugins-1.1.1.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/73/86/43fa9f15c5b9fb6e82620428827cd3c284aa933431405d1bcf5231ae3d3e/cligj-0.7.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c9/56/e7e69b427c3878352c2fb9b450bd0e19ed552753491d39d7d0a2f5226d41/cryptography-46.0.3-cp311-abi3-manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/fa/25/0be9314cd72fe2ee2ef89ceb1f438bc156428a12177d684040456eee4a56/cupy_xarray-0.1.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/8d/05/8efadba80e1296526e69c1dceba8b0f0bc3756e8d69f6ed9b0e647cf3169/cyclopts-4.4.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/5b/0eceb9a5990de9025733a0d212ca43649ba9facd58b8552b6bf93c11439d/cyclopts-4.4.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/25/3e/e27078370414ef35fafad2c06d182110073daaeb5d3bf734b0b1eeefe452/debugpy-1.8.19-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl
@ -522,10 +521,10 @@ environments:
- pypi: https://files.pythonhosted.org/packages/31/b3/802576f2ea5dcb48501bb162e4c7b7b3ca5654a42b2c968ef98a797a4c79/geographiclib-2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e5/15/cf2a69ade4b194aa524ac75112d5caac37414b20a3a03e6865dfe0bd1539/geopy-2.4.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/01/61/d4b89fec821f72385526e1b9d9a3a0385dda4a72b206d28049e2c7cd39b8/gitpython-3.1.45-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/ed/d4/90197b416cb61cefd316964fd9e7bd8324bcbafabf40eef14a9f20b81974/google_api_core-2.28.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/96/58/c1e716be1b055b504d80db2c8413f6c6a890a6ae218a65f178b63bc30356/google_api_python_client-2.187.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c6/97/451d55e05487a5cd6279a01a7e34921858b16f7dc8aa38a2c684743cd2b3/google_auth-2.45.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/db/18/79e9008530b79527e0d5f79e7eef08d3b179b7f851cfd3a2f27822fbdfa9/google_auth-2.47.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/99/d5/3c97526c8796d3caf5f4b3bed2b05e8a7102326f00a334e7a438237f3b22/google_auth_httplib2-0.3.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/89/20/bfa472e327c8edee00f04beecc80baeddd2ab33ee0e86fd7654da49d45e9/google_cloud_core-2.5.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/2d/80/6e5c7c83cea15ed4dfc4843b9df9db0716bc551ac938f7b5dd18a72bd5e4/google_cloud_storage-3.7.0-py3-none-any.whl
@ -537,13 +536,14 @@ environments:
- pypi: https://files.pythonhosted.org/packages/d6/49/1f35189c1ca136b2f041b72402f2eb718bdcb435d9e88729fe6f6909c45d/h5netcdf-1.7.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/d9/69/4402ea66272dacc10b298cca18ed73e1c0791ff2ae9ed218d3859f9698ac/h5py-3.15.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/8c/a2/0d269db0f6163be503775dc8b6a6fa15820cc9fdc866f6ba608d86b721f2/httplib2-0.31.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/94/56/c5e8db63ba0e27b310a0b4c384da555b361741e7d186044d31f400c0419e/icechunk-1.1.14-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/8c/d7/db466e07a21553441adbf915f0913a3f8fecece364cacb2392f11be267be/icechunk-1.1.15-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/4c/0f/b66d63d4a5426c09005d3713b056e634e00e69788fdc88d1ffe40e5b7654/ipycytoscape-1.3.3-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/ca/d3/642a6dc3db8ea558a9b5fbc83815b197861868dc98f98a789b85c7660670/ipyevents-2.0.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/00/60/249e3444fcd9c833704741769981cd02fe2c7ce94126b1394e7a3b26e543/ipyfilechooser-0.6.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a3/17/20c2552266728ceba271967b87919664ecc0e33efca29c3efc6baf88c5f9/ipykernel-7.1.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/49/69/e9858f2c0b99bf9f036348d1c84b8026f438bb6875effe6a9bcd9883dada/ipyleaflet-0.20.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/f1/df/8ee1c5dd1e3308b5d5b2f2dfea323bb2f3827da8d654abb6642051199049/ipython-9.8.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/86/92/162cfaee4ccf370465c5af1ce36a9eacec1becb552f2033bb3584e6f640a/ipython-9.9.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/56/6d/0d9848617b9f753b87f214f1c682592f7ca42de085f564352f10f0843026/ipywidgets-8.1.8-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/15/aa/0aca39a37d3c7eb941ba736ede56d689e7be91cab5d9ca846bde3999eba6/isodate-0.7.2-py3-none-any.whl
@ -558,11 +558,11 @@ environments:
- pypi: https://files.pythonhosted.org/packages/93/cf/be4e93afbfa0def2cd6fac9302071db0bd6d0617999ecbf53f92b9398de3/multiurl-0.3.7-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/97/1a/78b19893197ed7525edfa7f124a461626541e82aec694a468ba97755c24e/netcdf4-1.7.3-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7b/7a/a8d32501bb95ecff342004a674720164f95ad616f269450b3bc13dc88ae3/netcdf4-1.7.4-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/ae/d3/ff8f1b9968aa4dcd1da1880322ed492314cc920998182e549b586c895a17/numbagg-0.9.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c4/e6/d359fdd37498e74d26a167f7a51e54542e642ea47181eb4e643a69a066c3/numcodecs-0.16.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/b0/e0/760e73c111193db5ca37712a148e4807d1b0c60302ab31e4ada6528ca34d/numpy_groupies-0.11.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/4a/4e/44dbb46b3d1b0ec61afda8e84837870f2f9ace33c564317d59b70bc19d3e/nvidia_nccl_cu12-2.28.9-py3-none-manylinux_2_18_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/23/2d/609d0392d992259c6dc39881688a7fc13b1397a668bc360fbd68d1396f85/nvidia_nccl_cu12-2.29.2-py3-none-manylinux_2_18_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/99/e2/311fb383d9534eef7bfbe858fad931b6e3dbe85843c50592f50063c3bc83/odc_geo-0.4.10-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/84/99/6636f7097a5e461d560317024522279f52931b5a52c8caa0755a14d5f1fd/odc_loader-0.6.0-py3-none-any.whl
@ -572,11 +572,12 @@ environments:
- pypi: https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e7/c3/3031c931098de393393e1f93a38dc9ed6805d86bb801acc3cf2d5bd1e6b7/plotly-6.5.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a8/87/77cc11c7a9ea9fd05503def69e3d18605852cd0d4b0d3b8f15bbeb3ef1d1/pooch-1.8.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/cd/24/3b7a0818484df9c28172857af32c2397b6d8fcd99d9468bd4684f98ebf0a/proto_plus-1.27.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/56/13/333b8f421738f149d4fe5e49553bc2a2ab75235486259f689b4b91f96cec/protobuf-6.33.2-cp39-abi3-manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/ff/7b/e9a6fa461ef266c5a23485004934b8f08a2a8ddc447802161ea56d9837dd/psygnal-0.15.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/2d/4f/3593e5adb88a188c798604aed95fbc1479f30230e7f51e8f2c770e6a3832/psygnal-0.15.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl
@ -587,8 +588,9 @@ environments:
- pypi: https://files.pythonhosted.org/packages/82/06/cad54e8ce758bd836ee5411691cbd49efeb9cc611b374670fce299519334/pyshp-3.0.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/9f/86/3ec01436c6235a23a80e978b261a87481c1acaf626a5c618e9edac30e5e1/pystac-1.14.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/5d/d2/5f6367b14c9f250d1a6725d18bd1e9584f5ab1587e292f3a847e59189598/pystac_client-0.9.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/88/ae/baf3a8057d8129896a7e02619df43ea0d918fc5b2bb66eb6e2470595fbac/python_box-7.3.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7b/84/66c0d9cca2a09074ec2ce6fffa87709ca51b0d197ae742d835e841bac660/rasterio-1.4.4-cp313-cp313-manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/48/4a/1af9aa9810fb30668568f2c4dd3eec2412c8e9762b69201d971c509b295e/rasterio-1.5.0-cp313-cp313-manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/f2/98/7e6d147fd16a10a5f821db6e25f192265d6ecca3d82957a4fdd592cad49c/ratelim-0.1.6-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/34/83/a485250bc09db55e4b4389d99e583fac871ceeaaa4620b67a31d8db95ef5/rechunker-0.5.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/13/2f/b4530fbf948867702d0a3f27de4a6aab1d156f406d72852ab902c4d04de9/rich_rst-1.3.2-py3-none-any.whl
@ -608,16 +610,16 @@ environments:
- pypi: https://files.pythonhosted.org/packages/c0/95/6b7873f0267973ebd55ba9cd33a690b35a116f2779901ef6185a0e21864d/streamlit-1.52.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/72/35/d3cdab8cff94971714f866181abb1aa84ad976f6e7b6218a0499197465e4/streamlit_folium-0.25.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b5/fc/5e2988590ff2e0128eea6446806c904445a44e17256c67141573ea16b5a5/textual-6.11.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/84/38/47fab2a5fad163ca4851f7a20eb2442491cc63bf2756ec4ef161bc1461dd/textual-7.0.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/8d/c0/fdf9d3ee103ce66a55f0532835ad5e154226c5222423c6636ba049dc42fc/traittypes-0.2.3-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/94/fc/1d34ec891900d9337169ff9f8252fcaa633ae5c4d36b67effd849ed4f9ac/ty-0.0.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/e7/c1/56ef16bf5dcd255155cc736d276efa6ae0a5c26fd685e28f0412a4013c01/types_pytz-2025.2.0.20251108-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/95/20/92e3083b0e854943015bc8a7866e284ead9efadf9bf6809e6fce3b7ded61/ultraplot-1.66.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/43/6c/b26831b890b37c09882f6406efd31441c8e512bf1efbc967b9d867c5e02b/ultraplot-1.70.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a9/99/3ae339466c9183ea5b8ae87b34c0b897eda475d2aec2307cae60e5cd4f29/uritemplate-4.2.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/e6/9f/ca52771fe972e0dcc5167fedb609940e01516066938ff2ee28b273ae4f29/vega_datasets-0.9.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/04/d5/81d1403788f072e7d0e2b2fe539a0ae4410f27886ff52df094e5348c99ea/vegafusion-2.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/a7/6b/48f6d47a92eaf6f0dd235146307a7eb0d179b78d2faebc53aca3f1e49177/vl_convert_python-1.8.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/6f/61/dc6f4a38cf1b8699f64c57d7f021ca42c39bfe782d8a6eaefb7e8418e925/vl_convert_python-1.9.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/3f/0e/fa3b193432cfc60c93b42f3be03365f5f909d2b3ea410295cf36df739e31/widgetsnbextension-4.0.15-py3-none-any.whl
@ -884,10 +886,10 @@ packages:
- pkg:pypi/argon2-cffi-bindings?source=hash-mapping
size: 35943
timestamp: 1762509452935
- pypi: https://files.pythonhosted.org/packages/e0/b1/0542e0cab6f49f151a2d7a42400f84f706fc0b64e85dc1f56708b2e9fd37/array_api_compat-1.12.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/df/5d/493b1b5528ab5072feae30821ff3a07b7a0474213d548efb1fdf135f85c1/array_api_compat-1.13.0-py3-none-any.whl
name: array-api-compat
version: 1.12.0
sha256: a0b4795b6944a9507fde54679f9350e2ad2b1e2acf4a2408a098cdc27f890a8b
version: 1.13.0
sha256: c15026a0ddec42815383f07da285472e1b1ff2e632eb7afbcfe9b08fcbad9bf1
requires_dist:
- cupy ; extra == 'cupy'
- dask>=2024.9.0 ; extra == 'dask'
@ -905,16 +907,16 @@ packages:
- array-api-strict ; extra == 'dev'
- dask[array]>=2024.9.0 ; extra == 'dev'
- jax[cpu] ; extra == 'dev'
- ndonnx ; extra == 'dev'
- numpy>=1.22 ; extra == 'dev'
- pytest ; extra == 'dev'
- torch ; extra == 'dev'
- sparse>=0.15.1 ; extra == 'dev'
- ndonnx ; extra == 'dev'
requires_python: '>=3.10'
- pypi: https://files.pythonhosted.org/packages/1d/05/2709750ddb088eb2fc5053ba214b4f54334d15d4cb28217e2956b5507bac/array_api_extra-0.9.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/21/2b/bfa1cfe370dd4ed51f834f2c6ad93b7f6263b83615ab96ad91094cc98ec6/array_api_extra-0.9.2-py3-none-any.whl
name: array-api-extra
version: 0.9.1
sha256: 78b3e6605d1cdc9a66bb49e340e1bb620f045f1809a4e146d74500c3cb813b74
version: 0.9.2
sha256: d0643a9a4e981746057649accad068ca0fe4066d890f6a95d8b4cd5131b3b661
requires_dist:
- array-api-compat>=1.12.0,<2
requires_python: '>=3.10'
@ -1026,10 +1028,10 @@ packages:
- astropy[dev] ; extra == 'dev-all'
- astropy[test-all] ; extra == 'dev-all'
requires_python: '>=3.11'
- pypi: https://files.pythonhosted.org/packages/1f/07/50501947849e780cb5580ebcd7af08c14d431640562e18a8ac2b055c90ec/astropy_iers_data-0.2025.12.22.0.40.30-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/57/61/2d06c08f022c9b617b79f6c55d88e596c1795a1d211e6bf584ac4b9e9506/astropy_iers_data-0.2026.1.5.0.43.43-py3-none-any.whl
name: astropy-iers-data
version: 0.2025.12.22.0.40.30
sha256: 2fbc71988d96aa29566667c6568a2bc5ca00748174b1f8ac3e9f7b09d4c27cac
version: 0.2026.1.5.0.43.43
sha256: fe2c35e9abc99142083d717ea76bf7bde373dc12e502aaeced28ae4ff9bfc345
requires_dist:
- pytest ; extra == 'docs'
- hypothesis ; extra == 'test'
@ -1298,10 +1300,10 @@ packages:
purls: []
size: 249684
timestamp: 1761066654684
- pypi: https://files.pythonhosted.org/packages/3d/9e/1c90a122ea6180e8c72eb7294adc92531b0e08eb3d2324c2ba70d37f4802/azure_storage_blob-12.27.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/d8/3a/6ef2047a072e54e1142718d433d50e9514c999a58f51abfff7902f3a72f8/azure_storage_blob-12.28.0-py3-none-any.whl
name: azure-storage-blob
version: 12.27.1
sha256: 65d1e25a4628b7b6acd20ff7902d8da5b4fde8e46e19c8f6d213a3abc3ece272
version: 12.28.0
sha256: 00fb1db28bf6a7b7ecaa48e3b1d5c83bfadacc5a678b77826081304bd87d6461
requires_dist:
- azure-core>=1.30.0
- cryptography>=2.1.4
@ -1772,16 +1774,6 @@ packages:
- pkg:pypi/click?source=hash-mapping
size: 97676
timestamp: 1764518652276
- pypi: https://files.pythonhosted.org/packages/3d/9a/2abecb28ae875e39c8cad711eb1186d8d14eab564705325e77e4e6ab9ae5/click_plugins-1.1.1.2-py2.py3-none-any.whl
name: click-plugins
version: 1.1.1.2
sha256: 008d65743833ffc1f5417bf0e78e8d2c23aab04d9745ba817bd3e71b0feb6aa6
requires_dist:
- click>=4.0
- pytest>=3.6 ; extra == 'dev'
- pytest-cov ; extra == 'dev'
- wheel ; extra == 'dev'
- coveralls ; extra == 'dev'
- pypi: https://files.pythonhosted.org/packages/73/86/43fa9f15c5b9fb6e82620428827cd3c284aa933431405d1bcf5231ae3d3e/cligj-0.7.2-py3-none-any.whl
name: cligj
version: 0.7.2
@ -2517,10 +2509,10 @@ packages:
- pkg:pypi/cycler?source=hash-mapping
size: 14778
timestamp: 1764466758386
- pypi: https://files.pythonhosted.org/packages/8d/05/8efadba80e1296526e69c1dceba8b0f0bc3756e8d69f6ed9b0e647cf3169/cyclopts-4.4.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/5b/0eceb9a5990de9025733a0d212ca43649ba9facd58b8552b6bf93c11439d/cyclopts-4.4.4-py3-none-any.whl
name: cyclopts
version: 4.4.1
sha256: 67500e9fde90f335fddbf9c452d2e7c4f58209dffe52e7abb1e272796a963bde
version: 4.4.4
sha256: 316f798fe2f2a30cb70e7140cfde2a46617bfbb575d31bbfdc0b2410a447bd83
requires_dist:
- attrs>=23.1.0
- docstring-parser>=0.15,<4.0
@ -2869,7 +2861,7 @@ packages:
- pypi: ./
name: entropice
version: 0.1.0
sha256: cb0c27d2c23c64d7533c03e380cf55c40e82a4d52a0392a829fad06a4ca93736
sha256: c15584d2588d1f67ff2c69d6d3afa461fd1b9571d423497221dcb295dd7b1514
requires_dist:
- aiohttp>=3.12.11
- bokeh>=3.7.3
@ -2932,6 +2924,7 @@ packages:
- ty>=0.0.2,<0.0.3
- ruff>=0.14.9,<0.15
- pandas-stubs>=2.3.3.251201,<3
- pytest>=9.0.2,<10
requires_python: '>=3.13,<3.14'
- pypi: git+ssh://git@forgejo.tobiashoelzer.de:22222/tobias/entropy.git#9ca1bdf4afc4ac9b0ea29ebbc060ffecb5cffcf7
name: entropy
@ -3422,17 +3415,17 @@ packages:
requires_dist:
- smmap>=3.0.1,<6
requires_python: '>=3.7'
- pypi: https://files.pythonhosted.org/packages/01/61/d4b89fec821f72385526e1b9d9a3a0385dda4a72b206d28049e2c7cd39b8/gitpython-3.1.45-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl
name: gitpython
version: 3.1.45
sha256: 8908cb2e02fb3b93b7eb0f2827125cb699869470432cc885f019b8fd0fccff77
version: 3.1.46
sha256: 79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058
requires_dist:
- gitdb>=4.0.1,<5
- typing-extensions>=3.10.0.2 ; python_full_version < '3.10'
- coverage[toml] ; extra == 'test'
- ddt>=1.1.1,!=1.4.3 ; extra == 'test'
- mock ; python_full_version < '3.8' and extra == 'test'
- mypy ; extra == 'test'
- mypy==1.18.2 ; python_full_version >= '3.9' and extra == 'test'
- pre-commit ; extra == 'test'
- pytest>=7.3.1 ; extra == 'test'
- pytest-cov ; extra == 'test'
@ -3516,42 +3509,35 @@ packages:
- google-api-core>=1.31.5,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0
- uritemplate>=3.0.1,<5
requires_python: '>=3.7'
- pypi: https://files.pythonhosted.org/packages/c6/97/451d55e05487a5cd6279a01a7e34921858b16f7dc8aa38a2c684743cd2b3/google_auth-2.45.0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/db/18/79e9008530b79527e0d5f79e7eef08d3b179b7f851cfd3a2f27822fbdfa9/google_auth-2.47.0-py3-none-any.whl
name: google-auth
version: 2.45.0
sha256: 82344e86dc00410ef5382d99be677c6043d72e502b625aa4f4afa0bdacca0f36
version: 2.47.0
sha256: c516d68336bfde7cf0da26aab674a36fedcf04b37ac4edd59c597178760c3498
requires_dist:
- cachetools>=2.0.0,<7.0
- pyasn1-modules>=0.2.1
- rsa>=3.1.4,<5
- cryptography>=38.0.3 ; extra == 'cryptography'
- cryptography<39.0.0 ; python_full_version < '3.8' and extra == 'cryptography'
- aiohttp>=3.6.2,<4.0.0 ; extra == 'aiohttp'
- requests>=2.20.0,<3.0.0 ; extra == 'aiohttp'
- cryptography ; extra == 'enterprise-cert'
- pyopenssl ; extra == 'enterprise-cert'
- pyopenssl>=20.0.0 ; extra == 'pyopenssl'
- cryptography>=38.0.3 ; extra == 'pyopenssl'
- cryptography<39.0.0 ; python_full_version < '3.8' and extra == 'pyopenssl'
- pyjwt>=2.0 ; extra == 'pyjwt'
- cryptography>=38.0.3 ; extra == 'pyjwt'
- cryptography<39.0.0 ; python_full_version < '3.8' and extra == 'pyjwt'
- pyu2f>=0.1.5 ; extra == 'reauth'
- requests>=2.20.0,<3.0.0 ; extra == 'requests'
- grpcio ; extra == 'testing'
- flask ; extra == 'testing'
- freezegun ; extra == 'testing'
- mock ; extra == 'testing'
- oauth2client ; extra == 'testing'
- pyjwt>=2.0 ; extra == 'testing'
- cryptography>=38.0.3 ; extra == 'testing'
- cryptography<39.0.0 ; python_full_version < '3.8' and extra == 'testing'
- pytest ; extra == 'testing'
- pytest-cov ; extra == 'testing'
- pytest-localserver ; extra == 'testing'
- pyopenssl>=20.0.0 ; extra == 'testing'
- cryptography>=38.0.3 ; extra == 'testing'
- cryptography<39.0.0 ; python_full_version < '3.8' and extra == 'testing'
- pyu2f>=0.1.5 ; extra == 'testing'
- responses ; extra == 'testing'
- urllib3 ; extra == 'testing'
@ -3564,7 +3550,7 @@ packages:
- aiohttp<3.10.0 ; extra == 'testing'
- urllib3 ; extra == 'urllib3'
- packaging ; extra == 'urllib3'
requires_python: '>=3.7'
requires_python: '>=3.8'
- pypi: https://files.pythonhosted.org/packages/99/d5/3c97526c8796d3caf5f4b3bed2b05e8a7102326f00a334e7a438237f3b22/google_auth_httplib2-0.3.0-py3-none-any.whl
name: google-auth-httplib2
version: 0.3.0
@ -3767,10 +3753,10 @@ packages:
- pkg:pypi/hyperframe?source=hash-mapping
size: 17397
timestamp: 1737618427549
- pypi: https://files.pythonhosted.org/packages/94/56/c5e8db63ba0e27b310a0b4c384da555b361741e7d186044d31f400c0419e/icechunk-1.1.14-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/8c/d7/db466e07a21553441adbf915f0913a3f8fecece364cacb2392f11be267be/icechunk-1.1.15-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
name: icechunk
version: 1.1.14
sha256: adb01a0275144c58f741b5402e658930326e86f7b389e879065e01625c021f7c
version: 1.1.15
sha256: c9e0cc3c8623a48861470553dbb8b0f1e86600989f597ce41ecf47568d8d099d
requires_dist:
- zarr>=3,!=3.0.3
- boto3 ; extra == 'test'
@ -3926,6 +3912,11 @@ packages:
- pkg:pypi/importlib-metadata?source=hash-mapping
size: 34641
timestamp: 1747934053147
- pypi: https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl
name: iniconfig
version: 2.3.0
sha256: f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12
requires_python: '>=3.10'
- pypi: https://files.pythonhosted.org/packages/4c/0f/b66d63d4a5426c09005d3713b056e634e00e69788fdc88d1ffe40e5b7654/ipycytoscape-1.3.3-py2.py3-none-any.whl
name: ipycytoscape
version: 1.3.3
@ -4024,10 +4015,10 @@ packages:
- traittypes>=0.2.1,<3
- xyzservices>=2021.8.1
requires_python: '>=3.8'
- pypi: https://files.pythonhosted.org/packages/f1/df/8ee1c5dd1e3308b5d5b2f2dfea323bb2f3827da8d654abb6642051199049/ipython-9.8.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/86/92/162cfaee4ccf370465c5af1ce36a9eacec1becb552f2033bb3584e6f640a/ipython-9.9.0-py3-none-any.whl
name: ipython
version: 9.8.0
sha256: ebe6d1d58d7d988fbf23ff8ff6d8e1622cfdb194daf4b7b73b792c4ec3b85385
version: 9.9.0
sha256: b457fe9165df2b84e8ec909a97abcf2ed88f565970efba16b1f7229c283d252b
requires_dist:
- colorama>=0.4.4 ; sys_platform == 'win32'
- decorator>=4.3.2
@ -4067,7 +4058,8 @@ packages:
- pandas>2.1 ; extra == 'test-extra'
- trio>=0.1.0 ; extra == 'test-extra'
- matplotlib>3.9 ; extra == 'matplotlib'
- ipython[doc,matplotlib,test,test-extra] ; extra == 'all'
- ipython[doc,matplotlib,terminal,test,test-extra] ; extra == 'all'
- argcomplete>=3.0 ; extra == 'all'
requires_python: '>=3.11'
- pypi: https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl
name: ipython-pygments-lexers
@ -6859,14 +6851,15 @@ packages:
version: 1.6.0
sha256: 87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c
requires_python: '>=3.5'
- pypi: https://files.pythonhosted.org/packages/97/1a/78b19893197ed7525edfa7f124a461626541e82aec694a468ba97755c24e/netcdf4-1.7.3-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/7b/7a/a8d32501bb95ecff342004a674720164f95ad616f269450b3bc13dc88ae3/netcdf4-1.7.4-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
name: netcdf4
version: 1.7.3
sha256: 0c764ba6f6a1421cab5496097e8a1c4d2e36be2a04880dfd288bb61b348c217e
version: 1.7.4
sha256: a72c9f58767779ec14cb7451c3b56bdd8fdc027a792fac2062b14e090c5617f3
requires_dist:
- cftime
- certifi
- numpy
- numpy>=2.3.0 ; platform_machine == 'ARM64' and sys_platform == 'win32'
- numpy>=1.21.2 ; platform_machine != 'ARM64' or sys_platform != 'win32'
- cython ; extra == 'tests'
- packaging ; extra == 'tests'
- pytest ; extra == 'tests'
@ -7045,10 +7038,10 @@ packages:
- pkg:pypi/nvidia-ml-py?source=hash-mapping
size: 48971
timestamp: 1765209768013
- pypi: https://files.pythonhosted.org/packages/4a/4e/44dbb46b3d1b0ec61afda8e84837870f2f9ace33c564317d59b70bc19d3e/nvidia_nccl_cu12-2.28.9-py3-none-manylinux_2_18_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/23/2d/609d0392d992259c6dc39881688a7fc13b1397a668bc360fbd68d1396f85/nvidia_nccl_cu12-2.29.2-py3-none-manylinux_2_18_x86_64.whl
name: nvidia-nccl-cu12
version: 2.28.9
sha256: 485776daa8447da5da39681af455aa3b2c2586ddcf4af8772495e7c532c7e5ab
version: 2.29.2
sha256: 3a9a0bf4142126e0d0ed99ec202579bef8d007601f9fab75af60b10324666b12
requires_python: '>=3'
- conda: https://conda.anaconda.org/conda-forge/linux-64/nvtx-0.2.14-py313h07c4f96_0.conda
sha256: 9341cb332428242ab938c5fc202008c12430ec43b8b83511d327f14bf8fd6d96
@ -7501,6 +7494,17 @@ packages:
- xarray ; extra == 'dev-optional'
- plotly[dev-optional] ; extra == 'dev'
requires_python: '>=3.8'
- pypi: https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl
name: pluggy
version: 1.6.0
sha256: e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746
requires_dist:
- pre-commit ; extra == 'dev'
- tox ; extra == 'dev'
- pytest ; extra == 'testing'
- pytest-benchmark ; extra == 'testing'
- coverage ; extra == 'testing'
requires_python: '>=3.9'
- conda: https://conda.anaconda.org/conda-forge/noarch/polars-1.34.0-pyh6a1acc5_0.conda
sha256: 7e8bb10f4373202a0be760d9ac74f92c5e7e6095251180642678a8f57f10c58a
md5: d398dbcb3312bbebc2b2f3dbb98b4262
@ -7652,10 +7656,10 @@ packages:
- pkg:pypi/psutil?source=hash-mapping
size: 501735
timestamp: 1762092897061
- pypi: https://files.pythonhosted.org/packages/ff/7b/e9a6fa461ef266c5a23485004934b8f08a2a8ddc447802161ea56d9837dd/psygnal-0.15.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/2d/4f/3593e5adb88a188c798604aed95fbc1479f30230e7f51e8f2c770e6a3832/psygnal-0.15.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
name: psygnal
version: 0.15.0
sha256: a0172efeb861280bca05673989a4df21624f44344eff20b873d8c9d0edc01350
version: 0.15.1
sha256: e9fca977f5335deea39aed22e31d9795983e4f243e59a7d3c4105793adb7693d
requires_dist:
- wrapt ; extra == 'proxy'
- pydantic ; extra == 'pydantic'
@ -8013,6 +8017,26 @@ packages:
- pystac[validation]>=1.10.0
- python-dateutil>=2.8.2
requires_python: '>=3.10'
- pypi: https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl
name: pytest
version: 9.0.2
sha256: 711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b
requires_dist:
- colorama>=0.4 ; sys_platform == 'win32'
- exceptiongroup>=1 ; python_full_version < '3.11'
- iniconfig>=1.0.1
- packaging>=22
- pluggy>=1.5,<2
- pygments>=2.7.2
- tomli>=1 ; python_full_version < '3.11'
- argcomplete ; extra == 'dev'
- attrs>=19.2 ; extra == 'dev'
- hypothesis>=3.56 ; extra == 'dev'
- mock ; extra == 'dev'
- requests ; extra == 'dev'
- setuptools ; extra == 'dev'
- xmlschema ; extra == 'dev'
requires_python: '>=3.10'
- conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.13.11-hc97d973_100_cp313.conda
build_number: 100
sha256: 9cf014cf28e93ee242bacfbf664e8b45ae06e50b04291e640abeaeb0cba0364c
@ -8383,19 +8407,19 @@ packages:
license: LicenseRef-Custom
size: 6143
timestamp: 1765438804958
- pypi: https://files.pythonhosted.org/packages/7b/84/66c0d9cca2a09074ec2ce6fffa87709ca51b0d197ae742d835e841bac660/rasterio-1.4.4-cp313-cp313-manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/48/4a/1af9aa9810fb30668568f2c4dd3eec2412c8e9762b69201d971c509b295e/rasterio-1.5.0-cp313-cp313-manylinux_2_28_x86_64.whl
name: rasterio
version: 1.4.4
sha256: c072450caa96428b1218b030500bb908fd6f09bc013a88969ff81a124b6a112a
version: 1.5.0
sha256: 08a7580cbb9b3bd320bdf827e10c9b2424d0df066d8eef6f2feb37e154ce0c17
requires_dist:
- affine
- attrs
- certifi
- click>=4.0,!=8.2.*
- cligj>=0.5
- numpy>=1.24
- click-plugins
- numpy>=2
- pyparsing
- rasterio[docs,ipython,plot,s3,test] ; extra == 'all'
- ghp-import ; extra == 'docs'
- numpydoc ; extra == 'docs'
- sphinx ; extra == 'docs'
@ -8404,28 +8428,17 @@ packages:
- ipython>=2.0 ; extra == 'ipython'
- matplotlib ; extra == 'plot'
- boto3>=1.2.4 ; extra == 's3'
- aiohttp ; extra == 'test'
- boto3>=1.2.4 ; extra == 'test'
- fsspec ; extra == 'test'
- hypothesis ; extra == 'test'
- matplotlib ; extra == 'test'
- packaging ; extra == 'test'
- pytest-cov>=2.2.0 ; extra == 'test'
- pytest>=2.8.2 ; extra == 'test'
- requests ; extra == 'test'
- shapely ; extra == 'test'
- fsspec ; extra == 'all'
- sphinx-rtd-theme ; extra == 'all'
- ipython>=2.0 ; extra == 'all'
- packaging ; extra == 'all'
- ghp-import ; extra == 'all'
- boto3>=1.2.4 ; extra == 'all'
- matplotlib ; extra == 'all'
- sphinx-click ; extra == 'all'
- sphinx ; extra == 'all'
- pytest>=2.8.2 ; extra == 'all'
- hypothesis ; extra == 'all'
- shapely ; extra == 'all'
- numpydoc ; extra == 'all'
- pytest-cov>=2.2.0 ; extra == 'all'
requires_python: '>=3.10'
requires_python: '>=3.12'
- pypi: https://files.pythonhosted.org/packages/f2/98/7e6d147fd16a10a5f821db6e25f192265d6ecca3d82957a4fdd592cad49c/ratelim-0.1.6-py2.py3-none-any.whl
name: ratelim
version: 0.1.6
@ -9139,10 +9152,10 @@ packages:
- pkg:pypi/terminado?source=hash-mapping
size: 22452
timestamp: 1710262728753
- pypi: https://files.pythonhosted.org/packages/b5/fc/5e2988590ff2e0128eea6446806c904445a44e17256c67141573ea16b5a5/textual-6.11.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/84/38/47fab2a5fad163ca4851f7a20eb2442491cc63bf2756ec4ef161bc1461dd/textual-7.0.1-py3-none-any.whl
name: textual
version: 6.11.0
sha256: 9e663b73ed37123a9b13c16a0c85e09ef917a4cfded97814361ed5cccfa40f89
version: 7.0.1
sha256: f9b7d16fa9b640bfff2a2008bf31e3f2d4429dc85e07a9583be033840ed15174
requires_dist:
- markdown-it-py[linkify]>=2.1.0
- mdit-py-plugins
@ -9429,10 +9442,10 @@ packages:
license: BSD-3-Clause
size: 508347
timestamp: 1765407086135
- pypi: https://files.pythonhosted.org/packages/95/20/92e3083b0e854943015bc8a7866e284ead9efadf9bf6809e6fce3b7ded61/ultraplot-1.66.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/43/6c/b26831b890b37c09882f6406efd31441c8e512bf1efbc967b9d867c5e02b/ultraplot-1.70.0-py3-none-any.whl
name: ultraplot
version: 1.66.0
sha256: 87fecb897ca5c7d54b76ac81e5b8635be45d9c9d42d629469f1d283e6405f9e1
version: 1.70.0
sha256: 2b29d1b1e36bd6cf88458370825cfab2c62b9acab706a2cfa434660d7dc4bf74
requires_dist:
- numpy>=1.26.0
- matplotlib>=3.9,<3.11
@ -9496,10 +9509,10 @@ packages:
- packaging
- narwhals>=1.42
requires_python: '>=3.9'
- pypi: https://files.pythonhosted.org/packages/a7/6b/48f6d47a92eaf6f0dd235146307a7eb0d179b78d2faebc53aca3f1e49177/vl_convert_python-1.8.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/6f/61/dc6f4a38cf1b8699f64c57d7f021ca42c39bfe782d8a6eaefb7e8418e925/vl_convert_python-1.9.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
name: vl-convert-python
version: 1.8.0
sha256: b51264998e8fcc43dbce801484a950cfe6513cdc4c46b20604ef50989855a617
version: 1.9.0
sha256: 849e6773a7e05d58ab215386b1065e7713f4846b9ac6b0d743bb3e1b20337231
requires_python: '>=3.7'
- pypi: https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl
name: watchdog

View file

@ -66,7 +66,7 @@ dependencies = [
"pypalettes>=0.2.1,<0.3",
"ty>=0.0.2,<0.0.3",
"ruff>=0.14.9,<0.15",
"pandas-stubs>=2.3.3.251201,<3",
"pandas-stubs>=2.3.3.251201,<3", "pytest>=9.0.2,<10",
]
[project.scripts]

View file

@ -0,0 +1,195 @@
#!/usr/bin/env python
"""Recalculate test metrics and confusion matrix for existing training results.
This script loads previously trained models and recalculates test metrics
and confusion matrices for training runs that were completed before these
outputs were added to the training pipeline.
"""
import pickle
from pathlib import Path
import cupy as cp
import numpy as np
import toml
import torch
import xarray as xr
from sklearn import set_config
from sklearn.metrics import confusion_matrix
from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.paths import RESULTS_DIR
# Enable array_api_dispatch to handle CuPy/NumPy namespace properly
set_config(array_api_dispatch=True)
def recalculate_metrics(results_dir: Path):
"""Recalculate test metrics and confusion matrix for a training result.
Args:
results_dir: Path to the results directory containing the trained model.
"""
print(f"\nProcessing: {results_dir}")
# Load the search settings to get training configuration
settings_file = results_dir / "search_settings.toml"
if not settings_file.exists():
print(" ❌ Missing search_settings.toml, skipping...")
return
with open(settings_file) as f:
config = toml.load(f)
settings = config["settings"]
# Check if metrics already exist
test_metrics_file = results_dir / "test_metrics.toml"
cm_file = results_dir / "confusion_matrix.nc"
# if test_metrics_file.exists() and cm_file.exists():
# print(" ✓ Metrics already exist, skipping...")
# return
# Load the best estimator
best_model_file = results_dir / "best_estimator_model.pkl"
if not best_model_file.exists():
print(" ❌ Missing best_estimator_model.pkl, skipping...")
return
print(f" Loading best estimator from {best_model_file.name}...")
with open(best_model_file, "rb") as f:
best_estimator = pickle.load(f)
# Recreate the dataset ensemble
print(" Recreating training dataset...")
dataset_ensemble = DatasetEnsemble(
grid=settings["grid"],
level=settings["level"],
target=settings["target"],
members=settings.get(
"members",
[
"AlphaEarth",
"ArcticDEM",
"ERA5-yearly",
"ERA5-seasonal",
"ERA5-shoulder",
],
),
dimension_filters=settings.get("dimension_filters", {}),
variable_filters=settings.get("variable_filters", {}),
filter_target=settings.get("filter_target", False),
add_lonlat=settings.get("add_lonlat", True),
)
task = settings["task"]
model = settings["model"]
device = "torch" if model in ["espa"] else "cuda"
# Create training data
training_data = dataset_ensemble.create_cat_training_dataset(task=task, device=device)
# Prepare test data - match training.py's approach
print(" Preparing test data...")
# For XGBoost with CuPy arrays, convert y_test to CPU (same as training.py)
y_test = (
training_data.y.test.get()
if model == "xgboost" and isinstance(training_data.y.test, cp.ndarray)
else training_data.y.test
)
# Compute predictions on the test set (use original device data)
print(" Computing predictions on test set...")
y_pred = best_estimator.predict(training_data.X.test)
# Use torch
y_pred = torch.as_tensor(y_pred, device="cuda")
y_test = torch.as_tensor(y_test, device="cuda")
# Compute metrics manually to avoid device issues
print(" Computing test metrics...")
from sklearn.metrics import (
accuracy_score,
f1_score,
jaccard_score,
precision_score,
recall_score,
)
test_metrics = {}
if task == "binary":
test_metrics["accuracy"] = float(accuracy_score(y_test, y_pred))
test_metrics["recall"] = float(recall_score(y_test, y_pred))
test_metrics["precision"] = float(precision_score(y_test, y_pred))
test_metrics["f1"] = float(f1_score(y_test, y_pred))
test_metrics["jaccard"] = float(jaccard_score(y_test, y_pred))
else:
test_metrics["accuracy"] = float(accuracy_score(y_test, y_pred))
test_metrics["f1_macro"] = float(f1_score(y_test, y_pred, average="macro"))
test_metrics["f1_weighted"] = float(f1_score(y_test, y_pred, average="weighted"))
test_metrics["precision_macro"] = float(precision_score(y_test, y_pred, average="macro", zero_division=0))
test_metrics["precision_weighted"] = float(precision_score(y_test, y_pred, average="weighted", zero_division=0))
test_metrics["recall_macro"] = float(recall_score(y_test, y_pred, average="macro"))
test_metrics["jaccard_micro"] = float(jaccard_score(y_test, y_pred, average="micro"))
test_metrics["jaccard_macro"] = float(jaccard_score(y_test, y_pred, average="macro"))
test_metrics["jaccard_weighted"] = float(jaccard_score(y_test, y_pred, average="weighted"))
# Get confusion matrix
print(" Computing confusion matrix...")
labels = list(range(len(training_data.y.labels)))
labels = torch.as_tensor(np.array(labels), device="cuda")
print(" Device of y_test:", getattr(training_data.y.test, "device", "cpu"))
print(" Device of y_pred:", getattr(y_pred, "device", "cpu"))
print(" Device of labels:", getattr(labels, "device", "cpu"))
cm = confusion_matrix(y_test, y_pred, labels=labels)
cm = cm.cpu().numpy()
labels = labels.cpu().numpy()
label_names = [training_data.y.labels[i] for i in range(len(training_data.y.labels))]
cm_xr = xr.DataArray(
cm,
dims=["true_label", "predicted_label"],
coords={"true_label": label_names, "predicted_label": label_names},
name="confusion_matrix",
)
# Store the test metrics
if not test_metrics_file.exists():
print(f" Storing test metrics to {test_metrics_file.name}...")
with open(test_metrics_file, "w") as f:
toml.dump({"test_metrics": test_metrics}, f)
else:
print(" ✓ Test metrics already exist")
# Store the confusion matrix
if True:
# if not cm_file.exists():
print(f" Storing confusion matrix to {cm_file.name}...")
cm_xr.to_netcdf(cm_file, engine="h5netcdf")
else:
print(" ✓ Confusion matrix already exists")
print(" ✓ Done!")
def main():
"""Find all training results and recalculate metrics for those missing them."""
print("Searching for training results directories...")
# Find all results directories
results_dirs = sorted([d for d in RESULTS_DIR.glob("*") if d.is_dir()])
print(f"Found {len(results_dirs)} results directories.\n")
for results_dir in results_dirs:
recalculate_metrics(results_dir)
# try:
# except Exception as e:
# print(f" ❌ Error processing {results_dir.name}: {e}")
# continue
print("\n✅ All done!")
if __name__ == "__main__":
main()

58
scripts/rechunk_zarr.py Normal file
View file

@ -0,0 +1,58 @@
import xarray as xr
import zarr
from rich import print
import dask.distributed as dd
from entropice.utils.paths import get_era5_stores
import entropice.utils.codecs
def print_info(daily_raw = None, show_vars: bool = True):
if daily_raw is None:
daily_store = get_era5_stores("daily")
daily_raw = xr.open_zarr(daily_store, consolidated=False)
print("=== Daily INFO ===")
print(f" Dims: {daily_raw.sizes}")
numchunks = 1
chunksizes = {}
approxchunksize = 4 # 4 Bytes = float32
for d, cs in daily_raw.chunksizes.items():
numchunks *= len(cs)
chunksizes[d] = max(cs)
approxchunksize *= max(cs)
approxchunksize /= 10e6 # MB
print(f" Chunks: {chunksizes} (~{approxchunksize:.2f}MB) => {numchunks} total")
print(f" Encoding: {daily_raw.encoding}")
if show_vars:
print(" Variables:")
for var in daily_raw.data_vars:
da = daily_raw[var]
print(f" {var} Encoding:")
print(da.encoding)
print("")
def rechunk():
daily_store = get_era5_stores("daily")
daily_raw = xr.open_zarr(daily_store, consolidated=False)
print_info(daily_raw, False)
daily_raw = daily_raw.chunk({
"time": 120,
"latitude": -1, # Should be 337,
"longitude": -1 # Should be 3600
})
print_info(daily_raw, False)
encoding = entropice.utils.codecs.from_ds(daily_raw)
daily_store_rechunked = daily_store.with_stem(f"{daily_store.stem}_rechunked")
daily_raw.to_zarr(daily_store_rechunked, mode="w", encoding=encoding, consolidated=False)
if __name__ == "__main__":
with (
dd.LocalCluster(n_workers=1, threads_per_worker=10, memory_limit="100GB") as cluster,
dd.Client(cluster) as client,
):
print(client)
print(client.dashboard_link)
rechunk()
print("Done.")

View file

@ -0,0 +1,144 @@
#!/usr/bin/env python
"""Rerun inference for training results that are missing predicted probabilities.
This script searches through training result directories and identifies those that have
a trained model but are missing inference results. It then loads the model and dataset
configuration, reruns inference, and saves the results.
"""
import pickle
from pathlib import Path
import toml
from rich.console import Console
from rich.progress import track
from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.inference import predict_proba
from entropice.utils.paths import RESULTS_DIR
console = Console()
def find_incomplete_trainings() -> list[Path]:
"""Find training result directories missing inference results.
Returns:
list[Path]: List of directories with trained models but missing predictions.
"""
incomplete = []
if not RESULTS_DIR.exists():
console.print(f"[yellow]Results directory not found: {RESULTS_DIR}[/yellow]")
return incomplete
# Search for all training result directories
for result_dir in RESULTS_DIR.glob("*_cv*"):
if not result_dir.is_dir():
continue
model_file = result_dir / "best_estimator_model.pkl"
settings_file = result_dir / "search_settings.toml"
predictions_file = result_dir / "predicted_probabilities.parquet"
# Check if model and settings exist but predictions are missing
if model_file.exists() and settings_file.exists() and not predictions_file.exists():
incomplete.append(result_dir)
return incomplete
def rerun_inference(result_dir: Path) -> bool:
"""Rerun inference for a training result directory.
Args:
result_dir (Path): Path to the training result directory.
Returns:
bool: True if successful, False otherwise.
"""
try:
console.print(f"\n[cyan]Processing: {result_dir.name}[/cyan]")
# Load settings
settings_file = result_dir / "search_settings.toml"
with open(settings_file) as f:
settings_data = toml.load(f)
settings = settings_data["settings"]
# Reconstruct DatasetEnsemble from settings
ensemble = DatasetEnsemble(
grid=settings["grid"],
level=settings["level"],
target=settings["target"],
members=settings["members"],
dimension_filters=settings.get("dimension_filters", {}),
variable_filters=settings.get("variable_filters", {}),
filter_target=settings.get("filter_target", False),
add_lonlat=settings.get("add_lonlat", True),
)
# Load trained model
model_file = result_dir / "best_estimator_model.pkl"
with open(model_file, "rb") as f:
clf = pickle.load(f)
console.print("[green]✓[/green] Loaded model and settings")
# Get class labels
classes = settings["classes"]
# Run inference
console.print("[yellow]Running inference...[/yellow]")
preds = predict_proba(ensemble, clf=clf, classes=classes)
# Save predictions
preds_file = result_dir / "predicted_probabilities.parquet"
preds.to_parquet(preds_file)
console.print(f"[green]✓[/green] Saved {len(preds)} predictions to {preds_file.name}")
return True
except Exception as e:
console.print(f"[red]✗ Error processing {result_dir.name}: {e}[/red]")
import traceback
console.print(f"[red]{traceback.format_exc()}[/red]")
return False
def main():
"""Rerun missing inferences for incomplete training results."""
console.print("[bold blue]Searching for incomplete training results...[/bold blue]")
incomplete_dirs = find_incomplete_trainings()
if not incomplete_dirs:
console.print("[green]No incomplete trainings found. All trainings have predictions![/green]")
return
console.print(f"[yellow]Found {len(incomplete_dirs)} training(s) missing predictions:[/yellow]")
for d in incomplete_dirs:
console.print(f"{d.name}")
console.print(f"\n[bold]Processing {len(incomplete_dirs)} training result(s)...[/bold]\n")
successful = 0
failed = 0
for result_dir in track(incomplete_dirs, description="Rerunning inference"):
if rerun_inference(result_dir):
successful += 1
else:
failed += 1
console.print("\n[bold]Summary:[/bold]")
console.print(f" [green]Successful: {successful}[/green]")
console.print(f" [red]Failed: {failed}[/red]")
if __name__ == "__main__":
main()

View file

@ -10,9 +10,11 @@ import pandas as pd
import pydeck as pdk
import streamlit as st
from entropice.dashboard.utils.class_ordering import get_ordered_classes
from entropice.dashboard.utils.colors import get_cmap, get_palette
from entropice.dashboard.utils.geometry import fix_hex_geometry
from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.training import TrainingSettings
def render_performance_summary(results: pd.DataFrame, refit_metric: str):
@ -125,7 +127,7 @@ def render_performance_summary(results: pd.DataFrame, refit_metric: str):
)
def render_parameter_distributions(results: pd.DataFrame, settings: dict | None = None):
def render_parameter_distributions(results: pd.DataFrame, settings: TrainingSettings | None = None):
"""Render histograms of parameter distributions explored.
Args:
@ -1152,15 +1154,18 @@ def render_top_configurations(results: pd.DataFrame, metric: str, top_n: int = 1
@st.fragment
def render_confusion_matrix_map(result_path: Path, settings: dict):
"""Render 3D pydeck map showing confusion matrix results (TP, FP, TN, FN).
def render_confusion_matrix_map(result_path: Path, settings: TrainingSettings):
"""Render 3D pydeck map showing prediction results.
Uses true labels for elevation (height) and different shades of red for incorrect predictions
based on the predicted class.
Args:
result_path: Path to the training result directory.
settings: Settings dictionary containing grid, level, task, and target information.
"""
st.subheader("🗺️ Confusion Matrix Spatial Distribution")
st.subheader("🗺️ Prediction Results Map")
# Load predicted probabilities
preds_file = result_path / "predicted_probabilities.parquet"
@ -1190,62 +1195,41 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
st.error(f"Error loading training data: {e}")
return
# Get the labeled cells (those with true labels)
labeled_cells = training_data.dataset[training_data.dataset.index.isin(training_data.y.binned.index)]
# Get all cells from the complete dataset (not just test split)
# Use the full dataset which includes both train and test splits
all_cells = training_data.dataset.copy()
# Merge predictions with true labels
# Reset index to avoid ambiguity between index and column
labeled_gdf = labeled_cells.copy()
labeled_gdf = labeled_gdf.reset_index().rename(columns={"index": "cell_id"})
labeled_gdf["true_class"] = training_data.y.binned.loc[labeled_cells.index].to_numpy()
labeled_gdf = all_cells.reset_index().rename(columns={"index": "cell_id"})
labeled_gdf["true_class"] = training_data.y.binned.loc[all_cells.index].to_numpy()
# Merge with predictions - ensure we keep GeoDataFrame type
merged_df = labeled_gdf.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="inner")
# Merge with predictions - use left join to keep all cells
merged_df = labeled_gdf.merge(preds_gdf[["cell_id", "predicted_class"]], on="cell_id", how="left")
merged = gpd.GeoDataFrame(merged_df, geometry="geometry", crs=labeled_gdf.crs)
# Mark which cells have predictions (test split) vs not (training split)
merged["in_test_split"] = merged["predicted_class"].notna()
# For cells without predictions (training split), use true class as predicted class for visualization
merged["predicted_class"] = merged["predicted_class"].fillna(merged["true_class"])
if len(merged) == 0:
st.warning("No matching predictions found for labeled cells.")
return
# Determine confusion matrix category
def get_confusion_category(row):
true_label = row["true_class"]
pred_label = row["predicted_class"]
# Mark correct vs incorrect predictions (only meaningful for test split)
merged["is_correct"] = merged["true_class"] == merged["predicted_class"]
if task == "binary":
# For binary classification
if true_label == "RTS" and pred_label == "RTS":
return "True Positive"
elif true_label == "RTS" and pred_label == "No-RTS":
return "False Negative"
elif true_label == "No-RTS" and pred_label == "RTS":
return "False Positive"
else: # true_label == "No-RTS" and pred_label == "No-RTS"
return "True Negative"
else:
# For multiclass (count/density)
if true_label == pred_label:
return "Correct"
else:
return "Incorrect"
merged["confusion_category"] = merged.apply(get_confusion_category, axis=1)
# Get ordered class labels for the task
ordered_classes = get_ordered_classes(task)
# Create controls
col1, col2 = st.columns([3, 1])
col1, col2, col3 = st.columns([2, 1, 1])
with col1:
# Filter by confusion category
if task == "binary":
categories = [
"All",
"True Positive",
"False Positive",
"True Negative",
"False Negative",
]
else:
categories = ["All", "Correct", "Incorrect"]
# Filter by prediction correctness and split
categories = ["All", "Test Split Only", "Training Split Only", "Correct (Test)", "Incorrect (Test)"]
selected_category = st.selectbox(
"Filter by Category",
@ -1263,10 +1247,26 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
key="confusion_map_opacity",
)
with col3:
line_width = st.slider(
"Line Width",
min_value=0.5,
max_value=3.0,
value=1.0,
step=0.5,
key="confusion_map_line_width",
)
# Filter data if needed
if selected_category != "All":
display_gdf = merged[merged["confusion_category"] == selected_category].copy()
else:
if selected_category == "Test Split Only":
display_gdf = merged[merged["in_test_split"]].copy()
elif selected_category == "Training Split Only":
display_gdf = merged[~merged["in_test_split"]].copy()
elif selected_category == "Correct (Test)":
display_gdf = merged[merged["is_correct"] & merged["in_test_split"]].copy()
elif selected_category == "Incorrect (Test)":
display_gdf = merged[~merged["is_correct"] & merged["in_test_split"]].copy()
else: # "All"
display_gdf = merged.copy()
if len(display_gdf) == 0:
@ -1280,49 +1280,72 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
if grid == "hex":
display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(fix_hex_geometry)
# Assign colors based on confusion category
if task == "binary":
color_map = {
"True Positive": [46, 204, 113], # Green
"False Positive": [231, 76, 60], # Red
"True Negative": [52, 152, 219], # Blue
"False Negative": [241, 196, 15], # Yellow
}
# Get red material colormap for incorrect predictions
red_cmap = get_cmap("red_predictions") # Use red_material palette
n_classes = len(ordered_classes)
# Assign colors based on correctness
def get_color(row):
if row["is_correct"]:
# Green for correct predictions
return [46, 204, 113]
else:
color_map = {
"Correct": [46, 204, 113], # Green
"Incorrect": [231, 76, 60], # Red
}
display_gdf_wgs84["fill_color"] = display_gdf_wgs84["confusion_category"].map(color_map)
# Add elevation based on confusion category (higher for errors)
if task == "binary":
elevation_map = {
"True Positive": 0.8,
"False Positive": 1.0,
"True Negative": 0.3,
"False Negative": 1.0,
}
# Different shades of red for each predicted class (ordered)
pred_class = row["predicted_class"]
if pred_class in ordered_classes:
class_idx = ordered_classes.index(pred_class)
# Sample from red colormap based on class index
color_value = red_cmap(class_idx / max(n_classes - 1, 1))
return [int(color_value[0] * 255), int(color_value[1] * 255), int(color_value[2] * 255)]
else:
elevation_map = {
"Correct": 0.5,
"Incorrect": 1.0,
}
# Fallback red if class not found
return [231, 76, 60]
display_gdf_wgs84["elevation"] = display_gdf_wgs84["confusion_category"].map(elevation_map)
display_gdf_wgs84["fill_color"] = display_gdf_wgs84.apply(get_color, axis=1)
# Add line color based on split: blue for test split, orange for training split
def get_line_color(row):
if row["in_test_split"]:
return [52, 152, 219] # Blue for test split
else:
return [230, 126, 34] # Orange for training split
display_gdf_wgs84["line_color"] = display_gdf_wgs84.apply(get_line_color, axis=1)
# Add elevation based on TRUE label (not predicted)
# Map each true class to a height based on its position in the ordered list
def get_elevation(row):
true_class = row["true_class"]
if true_class in ordered_classes:
class_idx = ordered_classes.index(true_class)
# Normalize to 0-1 range based on class position
return (class_idx + 1) / n_classes
else:
return 0.5 # Default elevation
display_gdf_wgs84["elevation"] = display_gdf_wgs84.apply(get_elevation, axis=1)
# Convert to GeoJSON format
geojson_data = []
for _, row in display_gdf_wgs84.iterrows():
# Determine split and status for tooltip
split_name = "Test Split" if row["in_test_split"] else "Training Split"
if row["in_test_split"]:
status = "✓ Correct" if row["is_correct"] else "✗ Incorrect"
else:
status = "(No prediction - training data)"
feature = {
"type": "Feature",
"geometry": row["geometry"].__geo_interface__,
"properties": {
"true_class": str(row["true_class"]),
"predicted_class": str(row["predicted_class"]),
"confusion_category": str(row["confusion_category"]),
"predicted_class": str(row["predicted_class"]) if row["in_test_split"] else "N/A",
"is_correct": bool(row["is_correct"]),
"split": split_name,
"status": status,
"fill_color": row["fill_color"],
"line_color": row["line_color"],
"elevation": float(row["elevation"]),
},
}
@ -1338,8 +1361,8 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
extruded=True,
wireframe=False,
get_fill_color="properties.fill_color",
get_line_color=[80, 80, 80],
line_width_min_pixels=0.5,
get_line_color="properties.line_color",
line_width_min_pixels=line_width,
get_elevation="properties.elevation",
elevation_scale=500000,
pickable=True,
@ -1353,9 +1376,10 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
layers=[layer],
initial_view_state=view_state,
tooltip={
"html": "<b>True Label:</b> {true_class}<br/>"
"html": "<b>Split:</b> {split}<br/>"
"<b>True Label:</b> {true_class}<br/>"
"<b>Predicted Label:</b> {predicted_class}<br/>"
"<b>Category:</b> {confusion_category}",
"<b>Status:</b> {status}",
"style": {"backgroundColor": "steelblue", "color": "white"},
},
map_style="https://basemaps.cartocdn.com/gl/dark-matter-gl-style/style.json",
@ -1365,56 +1389,240 @@ def render_confusion_matrix_map(result_path: Path, settings: dict):
st.pydeck_chart(deck)
# Show statistics
col1, col2, col3 = st.columns(3)
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total Labeled Cells", len(merged))
if task == "binary":
with col2:
tp = len(merged[merged["confusion_category"] == "True Positive"])
fp = len(merged[merged["confusion_category"] == "False Positive"])
tn = len(merged[merged["confusion_category"] == "True Negative"])
fn = len(merged[merged["confusion_category"] == "False Negative"])
accuracy = (tp + tn) / len(merged) if len(merged) > 0 else 0
st.metric("Accuracy", f"{accuracy:.2%}")
test_count = len(merged[merged["in_test_split"]])
st.metric("Test Split", test_count)
with col3:
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
st.metric("F1 Score", f"{f1:.3f}")
train_count = len(merged[~merged["in_test_split"]])
st.metric("Training Split", train_count)
# Show confusion matrix counts
st.caption(f"TP: {tp} | FP: {fp} | TN: {tn} | FN: {fn}")
with col4:
test_cells = merged[merged["in_test_split"]]
if len(test_cells) > 0:
correct = len(test_cells[test_cells["is_correct"]])
accuracy = correct / len(test_cells)
st.metric("Test Accuracy", f"{accuracy:.2%}")
else:
with col2:
correct = len(merged[merged["confusion_category"] == "Correct"])
accuracy = correct / len(merged) if len(merged) > 0 else 0
st.metric("Accuracy", f"{accuracy:.2%}")
with col3:
incorrect = len(merged[merged["confusion_category"] == "Incorrect"])
st.metric("Incorrect", incorrect)
st.metric("Test Accuracy", "N/A")
# Add legend
with st.expander("Legend", expanded=True):
st.markdown("**Confusion Matrix Categories:**")
# Split indicators (border colors)
st.markdown("**Data Split (Border Color):**")
for category, color in color_map.items():
count = len(merged[merged["confusion_category"] == category])
percentage = count / len(merged) * 100 if len(merged) > 0 else 0
test_count = len(merged[merged["in_test_split"]])
train_count = len(merged[~merged["in_test_split"]])
st.markdown(
f'<div style="display: flex; align-items: center; margin-bottom: 4px;">'
f'<div style="width: 20px; height: 20px; background-color: rgb({color[0]}, {color[1]}, {color[2]}); '
f'<div style="width: 20px; height: 20px; background-color: #333; '
f'border: 2px solid rgb(52, 152, 219); margin-right: 8px; flex-shrink: 0;"></div>'
f"<span><b>Test Split</b> ({test_count} cells, {test_count / len(merged) * 100:.1f}%)</span></div>",
unsafe_allow_html=True,
)
st.markdown(
f'<div style="display: flex; align-items: center; margin-bottom: 12px;">'
f'<div style="width: 20px; height: 20px; background-color: #333; '
f'border: 2px solid rgb(230, 126, 34); margin-right: 8px; flex-shrink: 0;"></div>'
f"<span><b>Training Split</b> ({train_count} cells, {train_count / len(merged) * 100:.1f}%)</span></div>",
unsafe_allow_html=True,
)
st.markdown("---")
st.markdown("**Fill Color (Prediction Results):**")
# Correct predictions (test split only)
test_cells = merged[merged["in_test_split"]]
correct = len(test_cells[test_cells["is_correct"]]) if len(test_cells) > 0 else 0
incorrect = len(test_cells[~test_cells["is_correct"]]) if len(test_cells) > 0 else 0
st.markdown(
f'<div style="display: flex; align-items: center; margin-bottom: 8px;">'
f'<div style="width: 20px; height: 20px; background-color: rgb(46, 204, 113); '
f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
f"<span>{category}: {count} ({percentage:.1f}%)</span></div>",
f"<span><b>Correct Predictions (Test)</b> ({correct} cells, {correct / len(test_cells) * 100 if len(test_cells) > 0 else 0:.1f}%)</span></div>",
unsafe_allow_html=True,
)
# Incorrect predictions by predicted class (shades of red)
st.markdown(
f"<b>Incorrect Predictions by Predicted Class (Test)</b> ({incorrect} cells):", unsafe_allow_html=True
)
for class_idx, class_label in enumerate(ordered_classes):
# Get count of incorrect predictions for this predicted class (test split only)
count = len(test_cells[(~test_cells["is_correct"]) & (test_cells["predicted_class"] == class_label)])
if count > 0:
# Get color for this predicted class
color_value = red_cmap(class_idx / max(n_classes - 1, 1))
rgb = [int(color_value[0] * 255), int(color_value[1] * 255), int(color_value[2] * 255)]
percentage = count / incorrect * 100 if incorrect > 0 else 0
st.markdown(
f'<div style="display: flex; align-items: center; margin-bottom: 4px; margin-left: 20px;">'
f'<div style="width: 20px; height: 20px; background-color: rgb({rgb[0]}, {rgb[1]}, {rgb[2]}); '
f'margin-right: 8px; border: 1px solid #ccc; flex-shrink: 0;"></div>'
f"<span>Predicted as <i>{class_label}</i>: {count} ({percentage:.1f}%)</span></div>",
unsafe_allow_html=True,
)
# Note about training split
st.markdown(
f'<div style="margin-top: 8px; font-style: italic; color: #888;">'
f"Note: Training split cells ({train_count}) are shown with their true labels (green fill) "
f"since predictions are only available for the test split.</div>",
unsafe_allow_html=True,
)
st.markdown("---")
st.markdown("**Elevation (3D):**")
st.markdown("Height represents prediction confidence: Errors are elevated higher than correct predictions.")
# Show elevation mapping for each true class
st.markdown("Height represents the <b>true label</b>:", unsafe_allow_html=True)
for class_idx, class_label in enumerate(ordered_classes):
elevation_value = (class_idx + 1) / n_classes
height_km = elevation_value * 500 # Since elevation_scale is 500000
st.markdown(
f'<div style="margin-left: 20px; margin-bottom: 2px;">• <i>{class_label}</i>: {height_km:.0f} km</div>',
unsafe_allow_html=True,
)
st.info("💡 Rotate the map by holding Ctrl/Cmd and dragging.")
def render_confusion_matrix_heatmap(confusion_matrix: "xr.DataArray", task: str):
"""Render confusion matrix as an interactive heatmap.
Args:
confusion_matrix: xarray DataArray with dimensions (true_label, predicted_label).
task: Task type ('binary' or 'multiclass').
"""
import plotly.express as px
# Convert to DataFrame for plotting
cm_df = confusion_matrix.to_pandas()
# Get labels (convert numeric labels to semantic labels if possible)
true_labels = confusion_matrix.coords["true_label"].values
pred_labels = confusion_matrix.coords["predicted_label"].values
# For binary classification, map 0/1 to No-RTS/RTS
if task == "binary":
label_map = {0: "No-RTS", 1: "RTS"}
true_labels_str = [label_map.get(int(label), str(label)) for label in true_labels]
pred_labels_str = [label_map.get(int(label), str(label)) for label in pred_labels]
else:
# For multiclass, use numeric labels as is
true_labels_str = [str(label) for label in true_labels]
pred_labels_str = [str(label) for label in pred_labels]
# Rename DataFrame indices and columns for display
cm_df.index = true_labels_str
cm_df.columns = pred_labels_str
# Store raw counts for annotations
cm_counts = cm_df.copy()
# Normalize by row (true label) to get percentages
cm_normalized = cm_df.div(cm_df.sum(axis=1), axis=0)
# Create custom text annotations showing both percentage and count
text_annotations = []
for i, true_label in enumerate(true_labels_str):
row_annotations = []
for j, pred_label in enumerate(pred_labels_str):
count = int(cm_counts.iloc[i, j])
percentage = cm_normalized.iloc[i, j] * 100
row_annotations.append(f"{percentage:.1f}%<br>({count:,})")
text_annotations.append(row_annotations)
# Create heatmap with normalized values
fig = px.imshow(
cm_normalized,
labels=dict(x="Predicted Label", y="True Label", color="Proportion"),
x=pred_labels_str,
y=true_labels_str,
color_continuous_scale="Blues",
aspect="auto",
zmin=0,
zmax=1,
)
# Update with custom annotations
fig.update_traces(
text=text_annotations,
texttemplate="%{text}",
textfont={"size": 12},
)
# Update layout for better readability
fig.update_layout(
title="Confusion Matrix (Normalized by True Label)",
xaxis_title="Predicted Label",
yaxis_title="True Label",
height=500,
)
# Update colorbar to show percentage
fig.update_coloraxes(
colorbar=dict(
title="Proportion",
tickformat=".0%",
)
)
st.plotly_chart(fig, width="stretch")
st.caption(
"📊 Values show **row-normalized percentages** (percentage of each true class predicted as each label). "
"Raw counts shown in parentheses."
)
# Calculate and display metrics from confusion matrix
col1, col2, col3 = st.columns(3)
total_samples = int(cm_df.values.sum())
correct_predictions = int(np.trace(cm_df.values))
accuracy = correct_predictions / total_samples if total_samples > 0 else 0
with col1:
st.metric("Total Samples", f"{total_samples:,}")
with col2:
st.metric("Correct Predictions", f"{correct_predictions:,}")
with col3:
st.metric("Accuracy", f"{accuracy:.2%}")
# Add detailed breakdown for binary classification
if task == "binary":
st.markdown("#### Binary Classification Metrics")
# Extract TP, TN, FP, FN from confusion matrix
# Assuming 0=No-RTS (negative), 1=RTS (positive)
tn = int(cm_df.iloc[0, 0])
fp = int(cm_df.iloc[0, 1])
fn = int(cm_df.iloc[1, 0])
tp = int(cm_df.iloc[1, 1])
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("True Positives (TP)", f"{tp:,}")
with col2:
st.metric("True Negatives (TN)", f"{tn:,}")
with col3:
st.metric("False Positives (FP)", f"{fp:,}")
with col4:
st.metric("False Negatives (FN)", f"{fn:,}")

View file

@ -6,6 +6,7 @@ import plotly.graph_objects as go
import pydeck as pdk
import streamlit as st
from entropice.dashboard.utils.class_ordering import get_ordered_classes, sort_class_series
from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.geometry import fix_hex_geometry
from entropice.dashboard.utils.loaders import TrainingResult
@ -64,8 +65,9 @@ def render_class_distribution_histogram(predictions_gdf: gpd.GeoDataFrame, task:
"""
st.subheader("📊 Predicted Class Distribution")
# Get class counts
class_counts = predictions_gdf["predicted_class"].value_counts().sort_index()
# Get class counts and order them properly
class_counts = predictions_gdf["predicted_class"].value_counts()
class_counts = sort_class_series(class_counts, task)
# Get colors based on task
categories = class_counts.index.tolist()
@ -121,7 +123,7 @@ def render_spatial_distribution_stats(predictions_gdf: gpd.GeoDataFrame):
st.subheader("🌍 Spatial Coverage")
# Calculate spatial extent
bounds = predictions_gdf.total_bounds
bounds = predictions_gdf.to_crs("EPSG:4326").total_bounds
col1, col2, col3, col4 = st.columns(4)
@ -193,8 +195,8 @@ def render_inference_map(result: TrainingResult):
col1, col2, col3 = st.columns([2, 2, 1])
with col1:
# Get unique classes for filtering
all_classes = sorted(preds_gdf["predicted_class"].unique())
# Get unique classes for filtering (properly ordered)
all_classes = get_ordered_classes(task, preds_gdf["predicted_class"].unique().tolist())
filter_options = ["All Classes", *all_classes]
selected_filter = st.selectbox(
@ -240,11 +242,13 @@ def render_inference_map(result: TrainingResult):
if grid == "hex":
display_gdf_wgs84["geometry"] = display_gdf_wgs84["geometry"].apply(fix_hex_geometry)
# Assign colors based on predicted class
colors_palette = get_palette(task, len(all_classes))
# Assign colors based on predicted class (using canonical ordering)
# Get all possible classes for this task to ensure consistent colors
canonical_classes = get_ordered_classes(task)
colors_palette = get_palette(task, len(canonical_classes))
# Create color mapping for all classes
color_map = {cls: colors_palette[i] for i, cls in enumerate(all_classes)}
# Create color mapping for canonical classes
color_map = {cls: colors_palette[i] for i, cls in enumerate(canonical_classes)}
# Convert hex colors to RGB
def hex_to_rgb(hex_color):
@ -349,8 +353,9 @@ def render_class_comparison(predictions_gdf: gpd.GeoDataFrame, task: str):
"""
st.subheader("🔍 Class Comparison")
# Get class distribution
# Get class distribution and order properly
class_counts = predictions_gdf["predicted_class"].value_counts()
class_counts = sort_class_series(class_counts, task)
if len(class_counts) < 2:
st.info("Need at least 2 classes for comparison.")
@ -362,7 +367,13 @@ def render_class_comparison(predictions_gdf: gpd.GeoDataFrame, task: str):
with col1:
st.markdown("**Class Proportions")
colors = get_palette(task, len(class_counts))
# Get colors matching canonical order
canonical_classes = get_ordered_classes(task)
all_colors = get_palette(task, len(canonical_classes))
color_map = {cls: all_colors[i] for i, cls in enumerate(canonical_classes)}
# Extract colors for available classes in proper order
colors = [color_map[cls] for cls in class_counts.index]
fig = go.Figure(
data=[

View file

@ -271,7 +271,7 @@ def plot_era5_heatmap(era5_array: xr.DataArray) -> alt.Chart:
return chart
def plot_era5_time_heatmap(era5_array: xr.DataArray) -> alt.Chart:
def plot_era5_time_heatmap(era5_array: xr.DataArray) -> alt.Chart | None:
"""Create a heatmap showing ERA5 feature weights by variable and year (averaging over season).
This is specifically for seasonal/shoulder data to show temporal trends.

View file

@ -48,50 +48,78 @@ def create_sample_count_heatmap(
def create_sample_count_bar_chart(
sample_df: pd.DataFrame,
task_colors: list[str] | None = None,
target_color_maps: dict[str, list[str]] | None = None,
) -> go.Figure:
"""Create bar chart showing sample counts by grid, target, and task.
Args:
sample_df: DataFrame with columns: Grid, Target, Task, Samples (Coverage).
task_colors: Optional color palette for tasks. If None, uses default Plotly colors.
target_color_maps: Optional dictionary mapping target names ("rts", "mllabels") to color palettes.
If None, uses default Plotly colors.
Returns:
Plotly Figure object containing the bar chart visualization.
"""
fig = px.bar(
sample_df,
x="Grid",
y="Samples (Coverage)",
color="Task",
facet_col="Target",
barmode="group",
title="Sample Counts by Grid Configuration and Target Dataset",
labels={
"Grid": "Grid Configuration",
"Samples (Coverage)": "Number of Samples",
},
color_discrete_sequence=task_colors,
height=500,
# Create subplots manually to have better control over colors
from plotly.subplots import make_subplots
targets = sorted(sample_df["Target"].unique())
tasks = sorted(sample_df["Task"].unique())
fig = make_subplots(
rows=1,
cols=len(targets),
subplot_titles=[f"Target: {target}" for target in targets],
shared_yaxes=True,
)
for col_idx, target in enumerate(targets, 1):
target_data = sample_df[sample_df["Target"] == target]
# Get color palette for this target
colors = target_color_maps.get(target, None) if target_color_maps else None
for task_idx, task in enumerate(tasks):
task_data = target_data[target_data["Task"] == task]
color = colors[task_idx] if colors and task_idx < len(colors) else None
# Create a unique legendgroup per task so colors are consistent
fig.add_trace(
go.Bar(
x=task_data["Grid"],
y=task_data["Samples (Coverage)"],
name=task,
marker_color=color,
legendgroup=task, # Group by task name
showlegend=(col_idx == 1), # Only show legend for first subplot
),
row=1,
col=col_idx,
)
fig.update_layout(
title_text="Training Sample Counts by Grid Configuration and Target Dataset",
barmode="group",
height=500,
showlegend=True,
)
# Update facet labels to be cleaner
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.update_xaxes(tickangle=-45)
fig.update_yaxes(title_text="Number of Samples", row=1, col=1)
return fig
def create_feature_count_stacked_bar(
breakdown_df: pd.DataFrame,
source_colors: list[str] | None = None,
source_color_map: dict[str, str] | None = None,
) -> go.Figure:
"""Create stacked bar chart showing feature counts by data source.
Args:
breakdown_df: DataFrame with columns: Grid, Data Source, Number of Features.
source_colors: Optional color palette for data sources. If None, uses default Plotly colors.
source_color_map: Optional dictionary mapping data source names to specific colors.
If None, uses default Plotly colors.
Returns:
Plotly Figure object containing the stacked bar chart visualization.
@ -103,12 +131,12 @@ def create_feature_count_stacked_bar(
y="Number of Features",
color="Data Source",
barmode="stack",
title="Total Features by Data Source Across Grid Configurations",
title="Input Features by Data Source Across Grid Configurations",
labels={
"Grid": "Grid Configuration",
"Number of Features": "Number of Features",
},
color_discrete_sequence=source_colors,
color_discrete_map=source_color_map,
text_auto=False,
)
@ -136,7 +164,7 @@ def create_inference_cells_bar(
x="Grid",
y="Inference Cells",
color="Grid",
title="Inference Cells by Grid Configuration",
title="Spatial Coverage (Grid Cells with Complete Data)",
labels={
"Grid": "Grid Configuration",
"Inference Cells": "Number of Cells",
@ -170,7 +198,7 @@ def create_total_samples_bar(
x="Grid",
y="Total Samples",
color="Grid",
title="Total Samples by Grid Configuration",
title="Training Samples (Binary Task)",
labels={
"Grid": "Grid Configuration",
"Total Samples": "Number of Samples",
@ -188,14 +216,15 @@ def create_total_samples_bar(
def create_feature_breakdown_donut(
grid_data: pd.DataFrame,
grid_config: str,
source_colors: list[str] | None = None,
source_color_map: dict[str, str] | None = None,
) -> go.Figure:
"""Create donut chart for feature breakdown by data source for a specific grid.
Args:
grid_data: DataFrame with columns: Data Source, Number of Features.
grid_config: Grid configuration name for the title.
source_colors: Optional color palette for data sources. If None, uses default Plotly colors.
source_color_map: Optional dictionary mapping data source names to specific colors.
If None, uses default Plotly colors.
Returns:
Plotly Figure object containing the donut chart visualization.
@ -207,7 +236,8 @@ def create_feature_breakdown_donut(
values="Number of Features",
title=grid_config,
hole=0.4,
color_discrete_sequence=source_colors,
color_discrete_map=source_color_map,
color="Data Source",
)
fig.update_traces(textposition="inside", textinfo="percent")
@ -218,13 +248,14 @@ def create_feature_breakdown_donut(
def create_feature_distribution_pie(
breakdown_df: pd.DataFrame,
source_colors: list[str] | None = None,
source_color_map: dict[str, str] | None = None,
) -> go.Figure:
"""Create pie chart for feature distribution by data source.
Args:
breakdown_df: DataFrame with columns: Data Source, Number of Features.
source_colors: Optional color palette for data sources. If None, uses default Plotly colors.
source_color_map: Optional dictionary mapping data source names to specific colors.
If None, uses default Plotly colors.
Returns:
Plotly Figure object containing the pie chart visualization.
@ -236,7 +267,8 @@ def create_feature_distribution_pie(
values="Number of Features",
title="Feature Distribution by Data Source",
hole=0.4,
color_discrete_sequence=source_colors,
color_discrete_map=source_color_map,
color="Data Source",
)
fig.update_traces(textposition="inside", textinfo="percent+label")

View file

@ -0,0 +1,70 @@
"""Utilities for ordering predicted classes consistently across visualizations.
This module leverages the canonical class labels defined in the ML dataset module
to ensure consistent ordering across all visualizations.
"""
import pandas as pd
from entropice.utils.types import Task
# Canonical orderings imported from the ML pipeline
# Binary labels are defined inline in dataset.py: {False: "No RTS", True: "RTS"}
# Count/Density labels are defined in the bin_values function
BINARY_LABELS = ["No RTS", "RTS"]
COUNT_LABELS = ["None", "Very Few", "Few", "Several", "Many", "Very Many"]
DENSITY_LABELS = ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"]
CLASS_ORDERINGS: dict[Task | str, list[str]] = {
"binary": BINARY_LABELS,
"count": COUNT_LABELS,
"density": DENSITY_LABELS,
}
def get_ordered_classes(task: Task | str, available_classes: list[str] | None = None) -> list[str]:
"""Get properly ordered class labels for a given task.
This uses the same canonical ordering as defined in the ML dataset module,
ensuring consistency between training and inference visualizations.
Args:
task: Task type ('binary', 'count', 'density').
available_classes: Optional list of available classes to filter and order.
If None, returns all canonical classes for the task.
Returns:
List of class labels in proper order.
Examples:
>>> get_ordered_classes("binary")
['No RTS', 'RTS']
>>> get_ordered_classes("count", ["None", "Few", "Several"])
['None', 'Few', 'Several']
"""
canonical_order = CLASS_ORDERINGS[task]
if available_classes is None:
return canonical_order
# Filter canonical order to only include available classes, preserving order
return [cls for cls in canonical_order if cls in available_classes]
def sort_class_series(series: pd.Series, task: Task | str) -> pd.Series:
"""Sort a pandas Series with class labels according to canonical ordering.
Args:
series: Pandas Series with class labels as index.
task: Task type ('binary', 'count', 'density').
Returns:
Sorted Series with classes in canonical order.
"""
available_classes = series.index.tolist()
ordered_classes = get_ordered_classes(task, available_classes)
# Reindex to get proper order
return series.reindex(ordered_classes)

View file

@ -1,13 +1,9 @@
"""Geometry utilities for dashboard visualizations."""
import antimeridian
import streamlit as st
from shapely.geometry import shape
try:
import antimeridian
except ImportError:
antimeridian = None
def fix_hex_geometry(geom):
"""Fix hexagon geometry crossing the antimeridian.
@ -23,13 +19,9 @@ def fix_hex_geometry(geom):
Fixed geometry object with antimeridian issues resolved.
Note:
If the antimeridian library is not available or an error occurs,
returns the original geometry unchanged.
If an error occurs, returns the original geometry unchanged.
"""
if antimeridian is None:
return geom
try:
return shape(antimeridian.fix_shape(geom))
except ValueError as e:

View file

@ -35,6 +35,8 @@ class TrainingResult:
path: Path
settings: TrainingSettings
results: pd.DataFrame
metrics: dict[str, float]
confusion_matrix: xr.DataArray
created_at: float
available_metrics: list[str]
@ -44,23 +46,32 @@ class TrainingResult:
result_file = result_path / "search_results.parquet"
preds_file = result_path / "predicted_probabilities.parquet"
settings_file = result_path / "search_settings.toml"
metrics_file = result_path / "test_metrics.toml"
confusion_matrix_file = result_path / "confusion_matrix.nc"
if not result_file.exists():
raise FileNotFoundError(f"Missing results file in {result_path}")
if not settings_file.exists():
raise FileNotFoundError(f"Missing settings file in {result_path}")
if not preds_file.exists():
raise FileNotFoundError(f"Missing predictions file in {result_path}")
if not metrics_file.exists():
raise FileNotFoundError(f"Missing metrics file in {result_path}")
if not confusion_matrix_file.exists():
raise FileNotFoundError(f"Missing confusion matrix file in {result_path}")
created_at = result_path.stat().st_ctime
settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
results = pd.read_parquet(result_file)
metrics = toml.load(metrics_file)["test_metrics"]
confusion_matrix = xr.open_dataarray(confusion_matrix_file, engine="h5netcdf")
available_metrics = [col.replace("mean_test_", "") for col in results.columns if col.startswith("mean_test_")]
return cls(
path=result_path,
settings=settings,
results=results,
metrics=metrics,
confusion_matrix=confusion_matrix,
created_at=created_at,
available_metrics=available_metrics,
)
@ -126,7 +137,7 @@ def load_all_training_results() -> list[TrainingResult]:
try:
training_result = TrainingResult.from_path(result_path)
except FileNotFoundError as e:
st.warning(f"Skipping incomplete training result at {result_path}: {e}")
st.warning(f"Skipping incomplete training result: {e}")
continue
training_results.append(training_result)

View file

@ -14,7 +14,6 @@ from entropice.dashboard.plots.inference import (
from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
@st.fragment
def render_sidebar_selection(training_results: list[TrainingResult]) -> TrainingResult:
"""Render sidebar for training run selection.

View file

@ -13,8 +13,6 @@ from entropice.dashboard.plots.overview import (
create_feature_distribution_pie,
create_inference_cells_bar,
create_sample_count_bar_chart,
create_sample_count_heatmap,
create_total_samples_bar,
)
from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.loaders import load_all_training_results
@ -32,11 +30,9 @@ from entropice.utils.types import (
def render_sample_count_overview():
"""Render overview of sample counts per task+target+grid+level combination."""
st.subheader("📊 Sample Counts by Configuration")
st.markdown(
"""
This visualization shows the number of available samples for each combination of:
This visualization shows the number of available training samples for each combination of:
- **Task**: binary, count, density
- **Target Dataset**: darts_rts, darts_mllabels
- **Grid System**: hex, healpix
@ -47,54 +43,20 @@ def render_sample_count_overview():
# Get sample count DataFrame from cache
all_stats = load_all_default_dataset_statistics()
sample_df = DatasetStatistics.get_sample_count_df(all_stats)
target_datasets = ["darts_rts", "darts_mllabels"]
# Create tabs for different views
tab1, tab2, tab3 = st.tabs(["📈 Heatmap", "📊 Bar Chart", "📋 Data Table"])
with tab1:
st.markdown("### Sample Counts Heatmap")
st.markdown("Showing counts of samples with coverage")
# Create heatmap for each target dataset
for target in target_datasets:
target_df = sample_df[sample_df["Target"] == target.replace("darts_", "")]
# Pivot for heatmap: Grid x Task
pivot_df = target_df.pivot_table(
index="Grid",
columns="Task",
values="Samples (Coverage)",
aggfunc="mean",
)
# Sort index by grid type and level
sort_order = sample_df[["Grid", "Grid_Level_Sort"]].drop_duplicates().set_index("Grid")
pivot_df = pivot_df.reindex(sort_order.sort_values("Grid_Level_Sort").index)
# Get color palette for sample counts
sample_colors = get_palette(f"sample_counts_{target}", n_colors=10)
# Create and display heatmap
fig = create_sample_count_heatmap(pivot_df, target, colorscale=sample_colors)
st.plotly_chart(fig, width="stretch")
with tab2:
st.markdown("### Sample Counts Bar Chart")
st.markdown("Showing counts of samples with coverage")
# Get color palette for tasks
# Get color palettes for each target dataset
n_tasks = sample_df["Task"].nunique()
task_colors = get_palette("task_types", n_colors=n_tasks)
target_color_maps = {
"rts": get_palette("task_types", n_colors=n_tasks),
"mllabels": get_palette("data_sources", n_colors=n_tasks),
}
# Create and display bar chart
fig = create_sample_count_bar_chart(sample_df, task_colors=task_colors)
st.plotly_chart(fig, width="stretch")
with tab3:
st.markdown("### Detailed Sample Counts")
fig = create_sample_count_bar_chart(sample_df, target_color_maps=target_color_maps)
st.plotly_chart(fig, use_container_width=True)
# Display full table with formatting
st.markdown("#### Detailed Sample Counts")
display_df = sample_df[["Grid", "Target", "Task", "Samples (Coverage)", "Coverage %"]].copy()
# Format numbers with commas
@ -102,13 +64,18 @@ def render_sample_count_overview():
# Format coverage as percentage with 2 decimal places
display_df["Coverage %"] = display_df["Coverage %"].apply(lambda x: f"{x:.2f}%")
st.dataframe(display_df, hide_index=True, width="stretch")
st.dataframe(display_df, hide_index=True, use_container_width=True)
def render_feature_count_comparison():
"""Render static comparison of feature counts across all grid configurations."""
st.markdown("### Feature Count Comparison Across Grid Configurations")
st.markdown("Comparing feature counts for all grid configurations with all data sources enabled")
st.markdown(
"""
Comparing dataset characteristics for all grid configurations with all data sources enabled.
- **Features**: Total number of input features from all data sources
- **Spatial Coverage**: Number of grid cells with complete data coverage
"""
)
# Get data from cache
all_stats = load_all_default_dataset_statistics()
@ -116,87 +83,44 @@ def render_feature_count_comparison():
breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
# Create tabs for different comparison views
comp_tab1, comp_tab2, comp_tab3 = st.tabs(["📊 Bar Chart", "📈 Breakdown", "📋 Data Table"])
with comp_tab1:
st.markdown("#### Total Features by Grid Configuration")
# Get color palette for data sources
unique_sources = breakdown_df["Data Source"].unique()
# Get all unique data sources and create color map
unique_sources = sorted(breakdown_df["Data Source"].unique())
n_sources = len(unique_sources)
source_colors = get_palette("data_sources", n_colors=n_sources)
source_color_list = get_palette("data_sources", n_colors=n_sources)
source_color_map = dict(zip(unique_sources, source_color_list))
# Create and display stacked bar chart
fig = create_feature_count_stacked_bar(breakdown_df, source_colors=source_colors)
st.plotly_chart(fig, width="stretch")
fig = create_feature_count_stacked_bar(breakdown_df, source_color_map=source_color_map)
st.plotly_chart(fig, use_container_width=True)
# Add secondary metrics
col1, col2 = st.columns(2)
# Get color palette for grid configs
# Add spatial coverage metric
n_grids = len(comparison_df)
grid_colors = get_palette("grid_configs", n_colors=n_grids)
with col1:
fig_cells = create_inference_cells_bar(comparison_df, grid_colors=grid_colors)
st.plotly_chart(fig_cells, width="stretch")
with col2:
fig_samples = create_total_samples_bar(comparison_df, grid_colors=grid_colors)
st.plotly_chart(fig_samples, width="stretch")
with comp_tab2:
st.markdown("#### Feature Breakdown by Data Source")
st.markdown("Showing percentage contribution of each data source across all grid configurations")
# Get color palette for data sources
unique_sources = breakdown_df["Data Source"].unique()
n_sources = len(unique_sources)
source_colors = get_palette("data_sources", n_colors=n_sources)
# Create donut charts for each grid configuration
# Organize in a grid layout
num_grids = len(comparison_df)
cols_per_row = 3
num_rows = (num_grids + cols_per_row - 1) // cols_per_row
for row_idx in range(num_rows):
cols = st.columns(cols_per_row)
for col_idx in range(cols_per_row):
grid_idx = row_idx * cols_per_row + col_idx
if grid_idx < num_grids:
grid_config = comparison_df.iloc[grid_idx]["Grid"]
grid_data = breakdown_df[breakdown_df["Grid"] == grid_config]
with cols[col_idx]:
fig = create_feature_breakdown_donut(grid_data, grid_config, source_colors=source_colors)
st.plotly_chart(fig, width="stretch")
with comp_tab3:
st.markdown("#### Detailed Feature Count Comparison")
st.plotly_chart(fig_cells, use_container_width=True)
# Display full comparison table with formatting
st.markdown("#### Detailed Comparison Table")
display_df = comparison_df[
[
"Grid",
"Total Features",
"Data Sources",
"Inference Cells",
"Total Samples",
]
].copy()
# Format numbers with commas
for col in ["Total Features", "Inference Cells", "Total Samples"]:
for col in ["Total Features", "Inference Cells"]:
display_df[col] = display_df[col].apply(lambda x: f"{x:,}")
st.dataframe(display_df, hide_index=True, width="stretch")
st.dataframe(display_df, hide_index=True, use_container_width=True)
@st.fragment
def render_feature_count_explorer():
"""Render interactive detailed configuration explorer using fragments."""
st.markdown("### Detailed Configuration Explorer")
st.markdown("Select specific grid configuration and data sources for detailed statistics")
# Grid selection
@ -250,8 +174,6 @@ def render_feature_count_explorer():
# Show results if at least one member is selected
if selected_members:
st.markdown("---")
# Get statistics from cache (already loaded)
grid_stats = all_stats[selected_grid_config.id]
@ -301,12 +223,14 @@ def render_feature_count_explorer():
breakdown_df = pd.DataFrame(breakdown_data)
# Get color palette for data sources
n_sources = len(breakdown_df)
source_colors = get_palette("data_sources", n_colors=n_sources)
# Get all unique data sources and create color map
unique_sources = sorted(breakdown_df["Data Source"].unique())
n_sources = len(unique_sources)
source_color_list = get_palette("data_sources", n_colors=n_sources)
source_color_map = dict(zip(unique_sources, source_color_list))
# Create and display pie chart
fig = create_feature_distribution_pie(breakdown_df, source_colors=source_colors)
fig = create_feature_distribution_pie(breakdown_df, source_color_map=source_color_map)
st.plotly_chart(fig, width="stretch")
# Show detailed table
@ -363,38 +287,82 @@ def render_feature_count_explorer():
st.info("👆 Select at least one data source to see feature statistics")
def render_feature_count_section():
"""Render the feature count section with comparison and explorer."""
st.subheader("🔢 Feature Counts by Dataset Configuration")
st.markdown(
"""
This visualization shows the total number of features that would be generated
for different combinations of data sources and grid configurations.
"""
)
# Static comparison across all grids
render_feature_count_comparison()
st.divider()
# Interactive explorer for detailed analysis
render_feature_count_explorer()
def render_dataset_analysis():
"""Render the dataset analysis section with sample and feature counts."""
st.header("📈 Dataset Analysis")
# Create tabs for the two different analyses
analysis_tabs = st.tabs(["📊 Sample Counts", "🔢 Feature Counts"])
# Create tabs for different analysis views
analysis_tabs = st.tabs(
[
"📊 Training Samples",
"📈 Dataset Characteristics",
"🔍 Feature Breakdown",
"⚙️ Configuration Explorer",
]
)
with analysis_tabs[0]:
st.subheader("Training Samples by Configuration")
render_sample_count_overview()
with analysis_tabs[1]:
render_feature_count_section()
st.subheader("Dataset Characteristics Across Grid Configurations")
render_feature_count_comparison()
with analysis_tabs[2]:
st.subheader("Feature Breakdown by Data Source")
# Get data from cache
all_stats = load_all_default_dataset_statistics()
comparison_df = DatasetStatistics.get_feature_count_df(all_stats)
breakdown_df = DatasetStatistics.get_feature_breakdown_df(all_stats)
breakdown_df = breakdown_df.sort_values("Grid_Level_Sort")
# Get all unique data sources and create color map
unique_sources = sorted(breakdown_df["Data Source"].unique())
n_sources = len(unique_sources)
source_color_list = get_palette("data_sources", n_colors=n_sources)
source_color_map = dict(zip(unique_sources, source_color_list))
st.markdown("Showing percentage contribution of each data source across all grid configurations")
# Sparse Resolution girds
for res in ["sparse", "low", "medium"]:
cols = st.columns(2)
with cols[0]:
grid_configs_res = [gc for gc in grid_configs if gc.res == res and gc.grid == "hex"]
for gc in grid_configs_res:
grid_display = gc.display_name
grid_data = breakdown_df[breakdown_df["Grid"] == grid_display]
fig = create_feature_breakdown_donut(grid_data, grid_display, source_color_map=source_color_map)
st.plotly_chart(fig, width="stretch", key=f"donut_{grid_display}")
with cols[1]:
grid_configs_res = [gc for gc in grid_configs if gc.res == res and gc.grid == "healpix"]
for gc in grid_configs_res:
grid_display = gc.display_name
grid_data = breakdown_df[breakdown_df["Grid"] == grid_display]
fig = create_feature_breakdown_donut(grid_data, grid_display, source_color_map=source_color_map)
st.plotly_chart(fig, width="stretch", key=f"donut_{grid_display}")
# Create donut charts for each grid configuration
# num_grids = len(comparison_df)
# cols_per_row = 3
# num_rows = (num_grids + cols_per_row - 1) // cols_per_row
# for row_idx in range(num_rows):
# cols = st.columns(cols_per_row)
# for col_idx in range(cols_per_row):
# grid_idx = row_idx * cols_per_row + col_idx
# if grid_idx < num_grids:
# grid_config = comparison_df.iloc[grid_idx]["Grid"]
# grid_data = breakdown_df[breakdown_df["Grid"] == grid_config]
# with cols[col_idx]:
# fig = create_feature_breakdown_donut(grid_data, grid_config, source_color_map=source_color_map)
# st.plotly_chart(fig, use_container_width=True)
with analysis_tabs[3]:
st.subheader("Interactive Configuration Explorer")
render_feature_count_explorer()
def render_training_results_summary(training_results):

View file

@ -1,10 +1,13 @@
"""Training Results Analysis page: Analysis of training results and model performance."""
from typing import cast
import streamlit as st
from stopuhr import stopwatch
from entropice.dashboard.plots.hyperparameter_analysis import (
render_binned_parameter_space,
render_confusion_matrix_heatmap,
render_confusion_matrix_map,
render_espa_binned_parameter_space,
render_multi_metric_comparison,
@ -14,12 +17,12 @@ from entropice.dashboard.plots.hyperparameter_analysis import (
render_top_configurations,
)
from entropice.dashboard.utils.formatters import format_metric_name
from entropice.dashboard.utils.loaders import load_all_training_results
from entropice.dashboard.utils.loaders import TrainingResult, load_all_training_results
from entropice.dashboard.utils.stats import CVResultsStatistics
from entropice.utils.types import GridConfig
@st.fragment
def render_analysis_settings_sidebar(training_results):
def render_analysis_settings_sidebar(training_results: list[TrainingResult]) -> tuple[TrainingResult, str, str, int]:
"""Render sidebar for training run and analysis settings selection.
Args:
@ -42,7 +45,7 @@ def render_analysis_settings_sidebar(training_results):
key="training_run_select",
)
selected_result = training_options[selected_name]
selected_result = cast(TrainingResult, training_options[selected_name])
st.divider()
@ -52,22 +55,7 @@ def render_analysis_settings_sidebar(training_results):
available_metrics = selected_result.available_metrics
# Try to get refit metric from settings
refit_metric = selected_result.settings.refit_metric if hasattr(selected_result.settings, "refit_metric") else None
if not refit_metric or refit_metric not in available_metrics:
# Infer from task or use first available metric
task = selected_result.settings.task
if task == "binary" and "f1" in available_metrics:
refit_metric = "f1"
elif "f1_weighted" in available_metrics:
refit_metric = "f1_weighted"
elif "accuracy" in available_metrics:
refit_metric = "accuracy"
elif available_metrics:
refit_metric = available_metrics[0]
else:
st.error("No metrics found in results.")
return None, None, None, None
refit_metric = "f1" if selected_result.settings.task == "binary" else "f1_weighted"
if refit_metric in available_metrics:
default_metric_idx = available_metrics.index(refit_metric)
@ -97,7 +85,7 @@ def render_analysis_settings_sidebar(training_results):
return selected_result, selected_metric, refit_metric, top_n
def render_run_information(selected_result, refit_metric):
def render_run_information(selected_result: TrainingResult, refit_metric):
"""Render training run configuration overview.
Args:
@ -107,23 +95,86 @@ def render_run_information(selected_result, refit_metric):
"""
st.header("📋 Run Information")
col1, col2, col3, col4, col5, col6 = st.columns(6)
grid_config = GridConfig.from_grid_level(f"{selected_result.settings.grid}{selected_result.settings.level}") # ty:ignore[invalid-argument-type]
col1, col2, col3, col4, col5 = st.columns(5)
with col1:
st.metric("Task", selected_result.settings.task.capitalize())
with col2:
st.metric("Grid", selected_result.settings.grid.capitalize())
st.metric("Target", selected_result.settings.target.capitalize())
with col3:
st.metric("Level", selected_result.settings.level)
st.metric("Grid", grid_config.display_name)
with col4:
st.metric("Model", selected_result.settings.model.upper())
with col5:
st.metric("Trials", len(selected_result.results))
with col6:
st.metric("CV Splits", selected_result.settings.cv_splits)
st.caption(f"**Refit Metric:** {format_metric_name(refit_metric)}")
def render_test_metrics_section(selected_result: TrainingResult):
"""Render test metrics overview showing final model performance.
Args:
selected_result: The selected TrainingResult object.
"""
st.header("🎯 Test Set Performance")
st.caption("Performance metrics on the held-out test set (best model from hyperparameter search)")
test_metrics = selected_result.metrics
if not test_metrics:
st.warning("No test metrics available for this training run.")
return
# Display metrics in columns based on task type
task = selected_result.settings.task
if task == "binary":
# Binary classification metrics
col1, col2, col3, col4, col5 = st.columns(5)
with col1:
st.metric("Accuracy", f"{test_metrics.get('accuracy', 0):.4f}")
with col2:
st.metric("F1 Score", f"{test_metrics.get('f1', 0):.4f}")
with col3:
st.metric("Precision", f"{test_metrics.get('precision', 0):.4f}")
with col4:
st.metric("Recall", f"{test_metrics.get('recall', 0):.4f}")
with col5:
st.metric("Jaccard", f"{test_metrics.get('jaccard', 0):.4f}")
else:
# Multiclass metrics
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Accuracy", f"{test_metrics.get('accuracy', 0):.4f}")
with col2:
st.metric("F1 (Macro)", f"{test_metrics.get('f1_macro', 0):.4f}")
with col3:
st.metric("F1 (Weighted)", f"{test_metrics.get('f1_weighted', 0):.4f}")
col4, col5, col6 = st.columns(3)
with col4:
st.metric("Precision (Macro)", f"{test_metrics.get('precision_macro', 0):.4f}")
with col5:
st.metric("Precision (Weighted)", f"{test_metrics.get('precision_weighted', 0):.4f}")
with col6:
st.metric("Recall (Macro)", f"{test_metrics.get('recall_macro', 0):.4f}")
col7, col8, col9 = st.columns(3)
with col7:
st.metric("Jaccard (Micro)", f"{test_metrics.get('jaccard_micro', 0):.4f}")
with col8:
st.metric("Jaccard (Macro)", f"{test_metrics.get('jaccard_macro', 0):.4f}")
with col9:
st.metric("Jaccard (Weighted)", f"{test_metrics.get('jaccard_weighted', 0):.4f}")
def render_cv_statistics_section(selected_result, selected_metric):
"""Render cross-validation statistics for selected metric.
@ -133,6 +184,7 @@ def render_cv_statistics_section(selected_result, selected_metric):
"""
st.header("📈 Cross-Validation Statistics")
st.caption("Performance during hyperparameter search (averaged across CV folds)")
from entropice.dashboard.utils.stats import CVMetricStatistics
@ -158,6 +210,45 @@ def render_cv_statistics_section(selected_result, selected_metric):
if cv_stats.mean_cv_std is not None:
st.info(f"**Mean CV Std:** {cv_stats.mean_cv_std:.4f} - Average standard deviation across CV folds")
# Compare with test metric if available
if selected_metric in selected_result.metrics:
test_score = selected_result.metrics[selected_metric]
st.divider()
st.subheader("CV vs Test Performance")
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Best CV Score", f"{cv_stats.best_score:.4f}")
with col2:
st.metric("Test Score", f"{test_score:.4f}")
with col3:
delta = test_score - cv_stats.best_score
delta_pct = (delta / cv_stats.best_score * 100) if cv_stats.best_score != 0 else 0
st.metric("Difference", f"{delta:+.4f}", delta=f"{delta_pct:+.2f}%")
if abs(delta) > cv_stats.std_score:
st.warning(
"⚠️ Test performance differs significantly from CV performance. "
"This may indicate overfitting or data distribution mismatch."
)
def render_confusion_matrix_section(selected_result: TrainingResult):
"""Render confusion matrix visualization and analysis.
Args:
selected_result: The selected TrainingResult object.
"""
st.header("🎲 Confusion Matrix")
st.caption("Detailed breakdown of predictions on the test set")
if selected_result.confusion_matrix is None:
st.warning("No confusion matrix available for this training run.")
return
render_confusion_matrix_heatmap(selected_result.confusion_matrix, selected_result.settings.task)
def render_parameter_space_section(selected_result, selected_metric):
"""Render parameter space analysis section.
@ -292,8 +383,19 @@ def render_training_analysis_page():
st.divider()
# Test Metrics Section
render_test_metrics_section(selected_result)
st.divider()
# Confusion Matrix Section
render_confusion_matrix_section(selected_result)
st.divider()
# Performance Summary Section
st.header("📊 Performance Overview")
st.header("📊 CV Performance Overview")
st.caption("Summary of hyperparameter search results across all configurations")
render_performance_summary(results, refit_metric)
st.divider()

View file

@ -320,6 +320,7 @@ def render_arcticdem_view(ensemble: DatasetEnsemble, arcticdem_ds, targets):
render_arcticdem_map(arcticdem_ds, targets, ensemble.grid)
@st.fragment
def render_era5_view(ensemble: DatasetEnsemble, era5_data: dict[L2SourceDataset, tuple], targets):
"""Render ERA5 climate data analysis.

View file

@ -17,6 +17,7 @@ import json
from collections.abc import Generator
from dataclasses import asdict, dataclass, field
from functools import cached_property
from itertools import product
from typing import Literal, TypedDict
import cupy as cp
@ -30,6 +31,7 @@ import xarray as xr
from rich import pretty, traceback
from sklearn import set_config
from sklearn.model_selection import train_test_split
from stopuhr import stopwatch
import entropice.utils.paths
from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task
@ -295,6 +297,30 @@ class DatasetEnsemble:
temporal: Literal["yearly", "seasonal", "shoulder"],
) -> pd.DataFrame:
era5 = self._read_member("ERA5-" + temporal, targets)
if len(era5["cell_ids"]) == 0:
# No data for these cells - create empty DataFrame with expected columns
# Use the Dataset metadata to determine column structure
variables = list(era5.data_vars)
times = era5.coords["time"].to_numpy()
time_df = pd.DataFrame({"time": times})
time_df.index = pd.DatetimeIndex(times)
tempus = _get_era5_tempus(time_df, temporal)
unique_tempus = tempus.unique()
if "aggregations" in era5.dims:
aggs_list = era5.coords["aggregations"].to_numpy()
expected_cols = [
f"era5_{var}_{t}_{agg}" for var, t, agg in product(variables, unique_tempus, aggs_list)
]
else:
expected_cols = [f"era5_{var}_{t}" for var, t in product(variables, unique_tempus)]
return pd.DataFrame(
index=targets["cell_id"].values,
columns=expected_cols,
dtype=float,
)
era5_df = era5.to_dataframe()
era5_df["t"] = _get_era5_tempus(era5_df, temporal)
if "aggregations" not in era5.dims:
@ -303,20 +329,50 @@ class DatasetEnsemble:
else:
era5_df = era5_df.pivot_table(index="cell_ids", columns=["t", "aggregations"])
era5_df.columns = [f"era5_{var}_{t}_{agg}" for var, t, agg in era5_df.columns]
# Ensure all target cell_ids are present, fill missing with NaN
era5_df = era5_df.reindex(targets["cell_id"].values, fill_value=np.nan)
return era5_df
def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
embeddings = self._read_member("AlphaEarth", targets)["embeddings"]
if len(embeddings["cell_ids"]) == 0:
# No data for these cells - create empty DataFrame with expected columns
# Use the Dataset metadata to determine column structure
years = embeddings.coords["year"].to_numpy()
aggs = embeddings.coords["agg"].to_numpy()
bands = embeddings.coords["band"].to_numpy()
expected_cols = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in product(years, aggs, bands)]
return pd.DataFrame(
index=targets["cell_id"].values,
columns=expected_cols,
dtype=float,
)
embeddings_df = embeddings.to_dataframe(name="value")
embeddings_df = embeddings_df.pivot_table(index="cell_ids", columns=["year", "agg", "band"], values="value")
embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
# Ensure all target cell_ids are present, fill missing with NaN
embeddings_df = embeddings_df.reindex(targets["cell_id"].values, fill_value=np.nan)
return embeddings_df
def _prep_arcticdem(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
arcticdem = self._read_member("ArcticDEM", targets)
if len(arcticdem["cell_ids"]) == 0:
# No data for these cells - create empty DataFrame with expected columns
# Use the Dataset metadata to determine column structure
variables = list(arcticdem.data_vars)
aggs = arcticdem.coords["aggregations"].to_numpy()
expected_cols = [f"arcticdem_{var}_{agg}" for var, agg in product(variables, aggs)]
return pd.DataFrame(
index=targets["cell_id"].values,
columns=expected_cols,
dtype=float,
)
arcticdem_df = arcticdem.to_dataframe().pivot_table(index="cell_ids", columns="aggregations")
arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns]
# Ensure all target cell_ids are present, fill missing with NaN
arcticdem_df = arcticdem_df.reindex(targets["cell_id"].values, fill_value=np.nan)
return arcticdem_df
def get_stats(self) -> DatasetStats:
@ -374,16 +430,19 @@ class DatasetEnsemble:
# n: no cache, o: overwrite cache, r: read cache if exists
cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col)
if cache_mode == "r" and cache_file.exists():
with stopwatch("Loading dataset from cache"):
dataset = gpd.read_parquet(cache_file)
print(
f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
f" and {len(dataset.columns)} features."
)
return dataset
with stopwatch("Reading target"):
targets = self._read_target()
if filter_target_col is not None:
targets = targets.loc[targets[filter_target_col]]
print(f"Read and filtered target dataset. ({len(targets)} samples)")
with stopwatch("Preparing member datasets"):
member_dfs = []
for member in self.members:
if member.startswith("ERA5"):
@ -395,11 +454,14 @@ class DatasetEnsemble:
member_dfs.append(self._prep_arcticdem(targets))
else:
raise NotImplementedError(f"Member {member} not implemented.")
print("Prepared all member datasets. Joining...")
with stopwatch("Joining datasets"):
dataset = targets.set_index("cell_id").join(member_dfs)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
print("Joining complete.")
if cache_mode in ["o", "r"]:
with stopwatch("Saving dataset to cache"):
dataset.to_parquet(cache_file)
print(f"Saved dataset to cache at {cache_file}.")
return dataset

View file

@ -46,6 +46,11 @@ def predict_proba(
else:
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
X_batch = batch.drop(columns=cols_to_drop).dropna()
# Skip empty batches (all rows had NaN values)
if len(X_batch) == 0:
continue
cell_ids = X_batch.index.to_numpy()
cell_geoms = batch.loc[X_batch.index, "geometry"].to_numpy()
X_batch = X_batch.to_numpy(dtype="float64")

View file

@ -2,12 +2,14 @@
import pickle
from dataclasses import asdict, dataclass
from functools import partial
import cupy as cp
import cyclopts
import pandas as pd
import toml
import xarray as xr
from array_api_compat import get_namespace
from cuml.ensemble import RandomForestClassifier
from cuml.neighbors import KNeighborsClassifier
from entropy import ESPAClassifier
@ -15,6 +17,14 @@ from rich import pretty, traceback
from scipy.stats import loguniform, randint
from scipy.stats._distn_infrastructure import rv_continuous_frozen, rv_discrete_frozen
from sklearn import set_config
from sklearn.metrics import (
accuracy_score,
confusion_matrix,
f1_score,
jaccard_score,
precision_score,
recall_score,
)
from sklearn.model_selection import KFold, RandomizedSearchCV
from stopuhr import stopwatch
from xgboost.sklearn import XGBClassifier
@ -49,6 +59,26 @@ _metrics = {
}
# Compute other metrics - using predictions directly instead of re-predicting for each metric
# Use functools.partial for cleaner metric definitions with non-default parameters
_metric_functions = {
"accuracy": accuracy_score,
"recall": recall_score,
"precision": precision_score,
"f1": f1_score,
"jaccard": jaccard_score,
"recall_macro": partial(recall_score, average="macro"),
"recall_weighted": partial(recall_score, average="weighted"),
"precision_macro": partial(precision_score, average="macro"),
"precision_weighted": partial(precision_score, average="weighted"),
"f1_macro": partial(f1_score, average="macro"),
"f1_weighted": partial(f1_score, average="weighted"),
"jaccard_micro": partial(jaccard_score, average="micro"),
"jaccard_macro": partial(jaccard_score, average="macro"),
"jaccard_weighted": partial(jaccard_score, average="weighted"),
}
@cyclopts.Parameter("*")
@dataclass(frozen=True, kw_only=True)
class CVSettings:
@ -200,6 +230,26 @@ def random_cv(
print(f"Accuracy of the best parameters using the inner CV of the random search: {search.best_score_:.3f}")
print(f"Accuracy on test set: {test_accuracy:.3f}")
# Compute predictions on the test set
y_pred = best_estimator.predict(training_data.X.test)
labels = list(range(len(training_data.y.labels)))
xp = get_namespace(y_test)
y_test = xp.as_array(y_test)
y_pred = xp.as_array(y_pred)
labels = xp.as_array(labels)
test_metrics = {metric: _metric_functions[metric](y_test, y_pred) for metric in metrics}
# Get a confusion matrix
cm = confusion_matrix(y_test, y_pred, labels=labels)
label_names = [training_data.y.labels[i] for i in range(len(training_data.y.labels))]
cm = xr.DataArray(
xp.as_numpy(cm),
dims=["true_label", "predicted_label"],
coords={"true_label": label_names, "predicted_label": label_names},
name="confusion_matrix",
)
results_dir = get_cv_results_dir(
"random_search",
grid=dataset_ensemble.grid,
@ -239,6 +289,17 @@ def random_cv(
print(f"Storing CV results to {results_file}")
results.to_parquet(results_file)
# Store the test metrics
test_metrics_file = results_dir / "test_metrics.toml"
print(f"Storing test metrics to {test_metrics_file}")
with open(test_metrics_file, "w") as f:
toml.dump({"test_metrics": test_metrics}, f)
# Store the confusion matrix
cm_file = results_dir / "confusion_matrix.nc"
print(f"Storing confusion matrix to {cm_file}")
cm.to_netcdf(cm_file, engine="h5netcdf")
# Get the inner state of the best estimator
features = training_data.X.data.columns.tolist()

View file

@ -30,6 +30,7 @@ class GridConfig:
level: int
id: GridLevel
display_name: str
res: Literal["sparse", "low", "medium"]
sort_key: str
@classmethod
@ -46,11 +47,24 @@ class GridConfig:
display_name = f"{grid.capitalize()}-{level}"
resmap: dict[str, Literal["sparse", "low", "medium"]] = {
"hex3": "sparse",
"hex4": "sparse",
"hex5": "low",
"hex6": "medium",
"healpix6": "sparse",
"healpix7": "sparse",
"healpix8": "low",
"healpix9": "low",
"healpix10": "medium",
}
return cls(
grid=grid,
level=level,
id=grid_level,
display_name=display_name,
res=resmap[grid_level],
sort_key=f"{grid}_{level:02d}",
)

View file

@ -0,0 +1,33 @@
"""Debug script to check what _prep_arcticdem returns for a batch."""
from entropice.ml.dataset import DatasetEnsemble
ensemble = DatasetEnsemble(
grid="healpix",
level=10,
target="darts_mllabels",
members=["ArcticDEM"],
add_lonlat=True,
filter_target=False,
)
# Get targets
targets = ensemble._read_target()
print(f"Total targets: {len(targets)}")
# Get first batch of targets
batch_targets = targets.iloc[:100]
print(f"\nBatch targets: {len(batch_targets)}")
print(f"Cell IDs in batch: {batch_targets['cell_id'].values[:5]}")
# Try to prep ArcticDEM for this batch
print("\n" + "=" * 80)
print("Calling _prep_arcticdem...")
print("=" * 80)
arcticdem_df = ensemble._prep_arcticdem(batch_targets)
print(f"\nArcticDEM DataFrame shape: {arcticdem_df.shape}")
print(f"ArcticDEM DataFrame index: {arcticdem_df.index[:5].tolist() if len(arcticdem_df) > 0 else 'EMPTY'}")
print(
f"ArcticDEM DataFrame columns ({len(arcticdem_df.columns)}): {arcticdem_df.columns[:10].tolist() if len(arcticdem_df.columns) > 0 else 'NO COLUMNS'}"
)
print(f"Number of non-NaN rows: {arcticdem_df.notna().any(axis=1).sum()}")

View file

@ -0,0 +1,72 @@
"""Debug script to identify feature mismatch between training and inference."""
from entropice.ml.dataset import DatasetEnsemble
# Test with level 6 (the actual level used in production)
ensemble = DatasetEnsemble(
grid="healpix",
level=10,
target="darts_mllabels",
members=[
"AlphaEarth",
"ArcticDEM",
"ERA5-yearly",
"ERA5-seasonal",
"ERA5-shoulder",
],
add_lonlat=True,
filter_target=False,
)
print("=" * 80)
print("Creating training dataset...")
print("=" * 80)
training_data = ensemble.create_cat_training_dataset(task="binary", device="cpu")
training_features = set(training_data.X.data.columns)
print(f"\nTraining dataset created with {len(training_features)} features")
print(f"Sample features: {sorted(list(training_features))[:10]}")
print("\n" + "=" * 80)
print("Creating inference batch...")
print("=" * 80)
batch_generator = ensemble.create_batches(batch_size=100, cache_mode="n")
batch = next(batch_generator, None)
# for batch in batch_generator:
if batch is None:
print("ERROR: No batch created!")
else:
print(f"\nBatch created with {len(batch.columns)} columns")
print(f"Batch columns: {sorted(batch.columns)[:15]}")
# Simulate the column dropping in predict_proba (inference.py)
cols_to_drop = ["geometry"]
if ensemble.target == "darts_mllabels":
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
else:
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
print(f"\nColumns to drop: {cols_to_drop}")
inference_batch = batch.drop(columns=cols_to_drop)
inference_features = set(inference_batch.columns)
print(f"\nInference batch after dropping has {len(inference_features)} features")
print(f"Sample features: {sorted(list(inference_features))[:10]}")
print("\n" + "=" * 80)
print("COMPARISON")
print("=" * 80)
print(f"Training features: {len(training_features)}")
print(f"Inference features: {len(inference_features)}")
if training_features == inference_features:
print("\n✅ SUCCESS: Features match perfectly!")
else:
print("\n❌ MISMATCH DETECTED!")
only_in_training = training_features - inference_features
only_in_inference = inference_features - training_features
if only_in_training:
print(f"\n⚠️ Only in TRAINING ({len(only_in_training)}): {sorted(only_in_training)}")
if only_in_inference:
print(f"\n⚠️ Only in INFERENCE ({len(only_in_inference)}): {sorted(only_in_inference)}")

310
tests/test_dataset.py Normal file
View file

@ -0,0 +1,310 @@
"""Tests for dataset.py module, specifically DatasetEnsemble class."""
import geopandas as gpd
import numpy as np
import pytest
from entropice.ml.dataset import DatasetEnsemble
@pytest.fixture
def sample_ensemble():
"""Create a sample DatasetEnsemble for testing with minimal data."""
return DatasetEnsemble(
grid="hex",
level=3, # Use level 3 for much faster tests
target="darts_rts",
members=["AlphaEarth"], # Use only one member for faster tests
add_lonlat=True,
filter_target="darts_has_coverage", # Filter to reduce dataset size
)
@pytest.fixture
def sample_ensemble_mllabels():
"""Create a sample DatasetEnsemble with mllabels target."""
return DatasetEnsemble(
grid="hex",
level=3, # Use level 3 for much faster tests
target="darts_mllabels",
members=["AlphaEarth"], # Use only one member for faster tests
add_lonlat=True,
filter_target="dartsml_has_coverage", # Filter to reduce dataset size
)
class TestDatasetEnsemble:
"""Test suite for DatasetEnsemble class."""
def test_initialization(self, sample_ensemble):
"""Test that DatasetEnsemble initializes correctly."""
assert sample_ensemble.grid == "hex"
assert sample_ensemble.level == 3
assert sample_ensemble.target == "darts_rts"
assert "AlphaEarth" in sample_ensemble.members
assert sample_ensemble.add_lonlat is True
def test_covcol_property(self, sample_ensemble, sample_ensemble_mllabels):
"""Test that covcol property returns correct column name."""
assert sample_ensemble.covcol == "darts_has_coverage"
assert sample_ensemble_mllabels.covcol == "dartsml_has_coverage"
def test_taskcol_property(self, sample_ensemble, sample_ensemble_mllabels):
"""Test that taskcol returns correct column name for different tasks."""
assert sample_ensemble.taskcol("binary") == "darts_has_rts"
assert sample_ensemble.taskcol("count") == "darts_rts_count"
assert sample_ensemble.taskcol("density") == "darts_rts_density"
assert sample_ensemble_mllabels.taskcol("binary") == "dartsml_has_rts"
assert sample_ensemble_mllabels.taskcol("count") == "dartsml_rts_count"
assert sample_ensemble_mllabels.taskcol("density") == "dartsml_rts_density"
def test_create_returns_geodataframe(self, sample_ensemble):
"""Test that create() returns a GeoDataFrame."""
dataset = sample_ensemble.create(cache_mode="n")
assert isinstance(dataset, gpd.GeoDataFrame)
def test_create_has_expected_columns(self, sample_ensemble):
"""Test that create() returns dataset with expected columns."""
dataset = sample_ensemble.create(cache_mode="n")
# Should have geometry column
assert "geometry" in dataset.columns
# Should have lat/lon if add_lonlat is True
if sample_ensemble.add_lonlat:
assert "lon" in dataset.columns
assert "lat" in dataset.columns
# Should have target columns (darts_*)
assert any(col.startswith("darts_") for col in dataset.columns)
# Should have member data columns
assert len(dataset.columns) > 3 # More than just geometry, lat, lon
def test_create_batches_consistency(self, sample_ensemble):
"""Test that create_batches produces batches with consistent columns."""
batch_size = 100
batches = list(sample_ensemble.create_batches(batch_size=batch_size, cache_mode="n"))
if len(batches) == 0:
pytest.skip("No batches created (dataset might be empty)")
# All batches should have the same columns
first_batch_cols = set(batches[0].columns)
for i, batch in enumerate(batches[1:], start=1):
assert set(batch.columns) == first_batch_cols, (
f"Batch {i} has different columns than batch 0. "
f"Difference: {set(batch.columns).symmetric_difference(first_batch_cols)}"
)
def test_create_vs_create_batches_columns(self, sample_ensemble):
"""Test that create() and create_batches() return datasets with the same columns."""
full_dataset = sample_ensemble.create(cache_mode="n")
batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))
if len(batches) == 0:
pytest.skip("No batches created (dataset might be empty)")
# Columns should be identical
full_cols = set(full_dataset.columns)
batch_cols = set(batches[0].columns)
assert full_cols == batch_cols, (
f"Column mismatch between create() and create_batches().\n"
f"Only in create(): {full_cols - batch_cols}\n"
f"Only in create_batches(): {batch_cols - full_cols}"
)
def test_training_dataset_feature_columns(self, sample_ensemble):
"""Test that create_cat_training_dataset creates proper feature columns."""
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
# Get the columns used for model inputs
model_input_cols = set(training_data.X.data.columns)
# These columns should NOT be in model inputs
assert "geometry" not in model_input_cols
assert sample_ensemble.covcol not in model_input_cols
assert sample_ensemble.taskcol("binary") not in model_input_cols
# No darts_* columns should be in model inputs
for col in model_input_cols:
assert not col.startswith("darts_"), f"Column {col} should have been dropped from model inputs"
def test_training_dataset_feature_columns_mllabels(self, sample_ensemble_mllabels):
"""Test feature columns for mllabels target."""
training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")
model_input_cols = set(training_data.X.data.columns)
# These columns should NOT be in model inputs
assert "geometry" not in model_input_cols
assert sample_ensemble_mllabels.covcol not in model_input_cols
assert sample_ensemble_mllabels.taskcol("binary") not in model_input_cols
# No dartsml_* columns should be in model inputs
for col in model_input_cols:
assert not col.startswith("dartsml_"), f"Column {col} should have been dropped from model inputs"
def test_inference_vs_training_feature_consistency(self, sample_ensemble):
"""Test that inference batches have the same features as training data after column dropping.
This test simulates the workflow in training.py and inference.py to ensure
feature consistency between training and inference.
"""
# Step 1: Create training dataset (as in training.py)
# Use only a small subset by creating just one batch
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
training_feature_cols = set(training_data.X.data.columns)
# Step 2: Create inference batch (as in inference.py)
# Get just the first batch to speed up test
batch_generator = sample_ensemble.create_batches(batch_size=100, cache_mode="n")
batch = next(batch_generator, None)
if batch is None:
pytest.skip("No batches created (dataset might be empty)")
# Simulate the column dropping in predict_proba
cols_to_drop = ["geometry"]
if sample_ensemble.target == "darts_mllabels":
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
else:
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
inference_batch = batch.drop(columns=cols_to_drop)
inference_feature_cols = set(inference_batch.columns)
# The features should match!
assert training_feature_cols == inference_feature_cols, (
f"Feature mismatch between training and inference!\n"
f"Only in training: {training_feature_cols - inference_feature_cols}\n"
f"Only in inference: {inference_feature_cols - training_feature_cols}\n"
f"Training features ({len(training_feature_cols)}): {sorted(training_feature_cols)}\n"
f"Inference features ({len(inference_feature_cols)}): {sorted(inference_feature_cols)}"
)
def test_inference_vs_training_feature_consistency_mllabels(self, sample_ensemble_mllabels):
"""Test feature consistency for mllabels target."""
training_data = sample_ensemble_mllabels.create_cat_training_dataset(task="binary", device="cpu")
training_feature_cols = set(training_data.X.data.columns)
# Get just the first batch to speed up test
batch_generator = sample_ensemble_mllabels.create_batches(batch_size=100, cache_mode="n")
batch = next(batch_generator, None)
if batch is None:
pytest.skip("No batches created (dataset might be empty)")
# Simulate the column dropping in predict_proba
cols_to_drop = ["geometry"]
if sample_ensemble_mllabels.target == "darts_mllabels":
cols_to_drop += [col for col in batch.columns if col.startswith("dartsml_")]
else:
cols_to_drop += [col for col in batch.columns if col.startswith("darts_")]
inference_batch = batch.drop(columns=cols_to_drop)
inference_feature_cols = set(inference_batch.columns)
assert training_feature_cols == inference_feature_cols, (
f"Feature mismatch between training and inference!\n"
f"Only in training: {training_feature_cols - inference_feature_cols}\n"
f"Only in inference: {inference_feature_cols - training_feature_cols}"
)
def test_all_tasks_feature_consistency(self, sample_ensemble):
"""Test that all task types produce consistent features."""
tasks = ["binary", "count", "density"]
feature_sets = {}
for task in tasks:
training_data = sample_ensemble.create_cat_training_dataset(task=task, device="cpu")
feature_sets[task] = set(training_data.X.data.columns)
# All tasks should have the same features
binary_features = feature_sets["binary"]
for task, features in feature_sets.items():
assert features == binary_features, (
f"Task '{task}' has different features than 'binary'.\n"
f"Difference: {features.symmetric_difference(binary_features)}"
)
def test_training_dataset_shapes(self, sample_ensemble):
"""Test that training dataset has correct shapes."""
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
n_features = len(training_data.X.data.columns)
n_samples_train = training_data.X.train.shape[0]
n_samples_test = training_data.X.test.shape[0]
# Check X shapes
assert training_data.X.train.shape == (n_samples_train, n_features)
assert training_data.X.test.shape == (n_samples_test, n_features)
# Check y shapes
assert training_data.y.train.shape == (n_samples_train,)
assert training_data.y.test.shape == (n_samples_test,)
# Check that train + test = total samples
assert len(training_data.dataset) == n_samples_train + n_samples_test
def test_no_nan_in_training_features(self, sample_ensemble):
"""Test that training features don't contain NaN values."""
training_data = sample_ensemble.create_cat_training_dataset(task="binary", device="cpu")
# Convert to numpy for checking (handles both numpy and cupy arrays)
X_train = np.asarray(training_data.X.train)
X_test = np.asarray(training_data.X.test)
assert not np.isnan(X_train).any(), "Training features contain NaN values"
assert not np.isnan(X_test).any(), "Test features contain NaN values"
def test_batch_coverage(self, sample_ensemble):
"""Test that batches cover all data without duplication."""
full_dataset = sample_ensemble.create(cache_mode="n")
batches = list(sample_ensemble.create_batches(batch_size=100, cache_mode="n"))
if len(batches) == 0:
pytest.skip("No batches created (dataset might be empty)")
# Collect all cell_ids from batches
batch_cell_ids = set()
for batch in batches:
batch_ids = set(batch.index)
# Check for duplicates across batches
overlap = batch_cell_ids.intersection(batch_ids)
assert len(overlap) == 0, f"Found {len(overlap)} duplicate cell_ids across batches"
batch_cell_ids.update(batch_ids)
# Check that all cell_ids from full dataset are in batches
full_cell_ids = set(full_dataset.index)
assert batch_cell_ids == full_cell_ids, (
f"Batch coverage mismatch.\n"
f"Missing from batches: {full_cell_ids - batch_cell_ids}\n"
f"Extra in batches: {batch_cell_ids - full_cell_ids}"
)
class TestDatasetEnsembleEdgeCases:
"""Test edge cases and error handling."""
def test_invalid_task_raises_error(self, sample_ensemble):
"""Test that invalid task raises ValueError."""
with pytest.raises(ValueError, match="Invalid task"):
sample_ensemble.create_cat_training_dataset(task="invalid", device="cpu") # type: ignore
def test_stats_method(self, sample_ensemble):
"""Test that get_stats returns expected structure."""
stats = sample_ensemble.get_stats()
assert "target" in stats
assert "num_target_samples" in stats
assert "members" in stats
assert "total_features" in stats
# Check that members dict contains info for each member
for member in sample_ensemble.members:
assert member in stats["members"]
assert "variables" in stats["members"][member]
assert "num_features" in stats["members"][member]