# ONNX Tests

This crate contains ONNX models used to test the conversion of ONNX to Burn source code through the
`burn-import` crate. The tests are end-to-end: they verify that ONNX models are accurately converted
into Burn source code, that the generated code compiles without errors, and that it produces the
same output as the original ONNX model.
Here is the directory structure of this crate:

- `tests/<model>`: contains the ONNX model and the Python script that generates it.
- `tests/<model>/<model>.onnx`: the ONNX model, generated by the script.
- `tests/<model>/<model>.py`: the Python script responsible for generating the ONNX model using
  PyTorch.
- `tests/onnx_tests.rs`: the main test file, where all the tests are contained.
- `build.rs`: the build script that generates the ONNX models; it is executed by `cargo test`
  before the actual tests run.
## Setting up your Python environment

### With rye

You can use `rye` to set up a Python environment with the necessary dependencies. To do so, `cd`
into the `onnx-tests` directory and run `rye sync`. Assuming you are in the top-level `burn`
directory, you can run the following commands:

```shell
cd crates/burn-import/onnx-tests
rye sync # or rye sync -f
```

This will create a `.venv` in the `onnx-tests` directory.
To add a new test, you need `onnx==1.15.0` and `torch==2.1.1` installed in your Python environment.
## Adding new tests

Here are the steps to add a new test:

1. Add your Python script to the `tests/<model>` directory. Refer to existing scripts for examples.
2. Run your Python script to generate the ONNX model, and inspect the model's output with the test
   data. Use these inputs and outputs in your test.
3. Make sure the ONNX output contains the desired operators by verifying with the Netron app.
   Sometimes PyTorch will optimize the model and remove operators that are not necessary for the
   model to run. If this happens, you can disable optimization by setting
   `torch.onnx.export(..., do_constant_folding=False)`.
4. Add an entry to the `build.rs` file to account for the generation of the new ONNX model.
5. Add an entry to `include_models!` in `tests/onnx_tests.rs` to include the new ONNX model in the
   tests.
6. Include a test in `tests/onnx_tests.rs` to test the new ONNX model.
7. Run `cargo test` to ensure your test passes.