Commit Graph

90 Commits

Author SHA1 Message Date
Nathaniel Simard a567c6e888
Fusion mix precision (#2247) 2024-09-05 10:53:26 -04:00
Guillaume Lagrange cc214d366c
Nonzero should return an empty vec for zero tensors (#2212)
* Nonzero should return an empty vec for zero tensors

* Add nonzero empty test

* Add missing import

---------

Co-authored-by: Nathaniel Simard <nathaniel.simard.42@gmail.com>
2024-09-03 09:00:58 -04:00
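
A minimal sketch of the behavior this fix pins down, assuming the NdArray backend and burn's `greater_elem`/`nonzero` tensor API (the setup is illustrative, not the PR's actual test):

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    // An all-zero tensor has no nonzero entries.
    let zeros = Tensor::<NdArray, 1>::zeros([4], &device);
    // After the fix, the result is an empty vec of index tensors rather
    // than a vec containing degenerate zero-length tensors.
    let indices = zeros.greater_elem(0.0).nonzero();
    assert!(indices.is_empty());
}
```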
Guillaume Lagrange 09a15e7e15
Avoid 0 denominator in interpolate frac (#2224) 2024-09-01 16:37:32 -04:00
Adrian Müller 2f4c5ac0a1
Feat: Allow onnx-import expand op with non-const shapes (#2189)
* Feat: Allow onnx-import expand op with non-const shapes

* Generalize ONNX Expand across IntElem
2024-08-29 13:15:44 -04:00
AlteredOxide 0292967000
Feature/codegen gather indices greater than rank 1 (#2199)
* implemented multi-dim index for GatherNode

The `NodeCodegen` impl for `GatherNode` now performs gather in complete
accordance with the ONNX Gather spec.
- a `gather` function was added to the gather.rs file
- `gather()` is now called within the codegen instead of `tensor.select()`
- a test with two test cases has been added
    - test axes 0 and 1
    - both use 2D index tensors

* add gather_onnx to numeric api

Added int and float implementations of gather to the burn-tensor numeric
api:
- named the methods `gather_onnx` to avoid confusion with the current
  `gather`
- these implementations follow the `Gather` ONNX spec

Updated the gather*.py variants and their onnx outputs

* modified files didn't end up in last commit

* tests passing for onnx gather

The implementation of gather for the ONNX `Gather` spec is tentatively
complete:
- py test models are updated
- onnx_tests are modified and passing: `gather`, `gather_scalar`, and
  `gather_shape`
- node/gather tests are passing

NOTE: The two additional tests in crates/burn-import/src/burn/node/gather.rs that test
the actual functionality of gather are likely to be deleted, since they
are redundant with the tests in
crates/burn-import/onnx-tests/tests/onnx_tests.rs.

* inlined onnx gather within codegen

* rm gather_onnx from public api; rm unnecessary tests

* add comments to gather py models

* some codegen changes; formatting to appease run-checks

- Some necessary changes and improvements to the codegen-inlined code
  after translating it from the public api (removed in the previous commit).
- Changed some formatting that run-checks complained about.

* simplify gather codegen; include 1d and 2d onnx tests

Modified the `Gather` codegen per requested changes:
- combined match statements on index
- remove use of `alloc::vec::Vec`
- use map -> collect instead of procedural
- include a 1d index gather onnx test
- remove superfluous tests

* delete unused gather.onnx
2024-08-28 07:51:19 -04:00
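
For context on why `tensor.select()` alone was not enough: ONNX `Gather` allows index tensors of rank > 1, and the gathered axis of the output is replaced by the full index shape. A hedged sketch of the reduction conceptually performed (axis 0, 2-D indices; illustrative, not the actual generated code):

```rust
use burn::backend::NdArray;
use burn::tensor::{Int, Tensor};

// Flatten the indices, select along the axis, then restore the index
// shape in place of the gathered axis.
fn onnx_gather_axis0(
    data: Tensor<NdArray, 2>,
    indices: Tensor<NdArray, 2, Int>,
) -> Tensor<NdArray, 3> {
    let [i0, i1] = indices.dims();
    let [_, d1] = data.dims();
    let flat = indices.reshape([i0 * i1]);
    data.select(0, flat).reshape([i0, i1, d1])
}

fn main() {
    let device = Default::default();
    let data = Tensor::<NdArray, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    let idx = Tensor::<NdArray, 2, Int>::from_ints([[0, 1], [1, 0]], &device);
    // Output shape [2, 2, 2]: the index shape replaces the gathered axis.
    println!("{}", onnx_gather_axis0(data, idx));
}
```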
syl20bnr 8e78106680 Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
Dilshod Tadjibaev 75a2850047
Add closeness tensor report (#2184)
* Add closeness tensor report

* Add documentation section

* Fix for no-std

* Fix epsilon formatting

* Update report.rs

* Fix import references

* Fix doc test

* Use colored crate instead of passing codes

* Small refactor to use iter directly

* Move colored dep to std

* Add missing

* Fix missing epsilon
2024-08-22 10:19:27 -05:00
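
A usage sketch of the closeness report, assuming it is exposed as a `check_closeness` helper in burn-tensor (the exact module path and signature are assumptions):

```rust
use burn::backend::NdArray;
use burn::tensor::{report::check_closeness, Tensor};

fn main() {
    let device = Default::default();
    let a = Tensor::<NdArray, 1>::from_floats([1.0, 2.0, 3.0], &device);
    let b = Tensor::<NdArray, 1>::from_floats([1.0, 2.0001, 3.1], &device);
    // Prints a colored breakdown of how many elements agree at each
    // epsilon tolerance; useful when comparing outputs across backends.
    check_closeness(&a, &b);
}
```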
Dilshod Tadjibaev d4a1d2026d
Fix equal/not-equal infinity numbers for burn-ndarray (#2166) 2024-08-15 12:33:54 -05:00
Guillaume Lagrange d2699022df
Add 0-dim tensor checks for creation ops and validate TensorData shape w/ num values (#2137) 2024-08-15 09:54:22 -04:00
Dilshod Tadjibaev 1c681f46ec
Precision option for tensor display (#2139) 2024-08-08 15:01:42 -05:00
Dilshod Tadjibaev 6b61ad5a61
Fix #2091 bug (in-place after expand) (#2114) 2024-08-07 17:37:20 -04:00
Genna Wingert a01004dd4a
Add Hard sigmoid activation function (#2112)
* Add Hard Sigmoid activation function

* Add ONNX import conversion for HardSigmoid

* Update supported operators list

* Update book

* Make test comparison approximate to eliminate precision issues

* Add burn-candle test

* Fix name in E2E test generator
2024-08-07 13:01:42 -05:00
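
The op itself is a clamped linear ramp: hard_sigmoid(x) = max(0, min(1, alpha * x + beta)), with ONNX defaults alpha = 0.2 and beta = 0.5. A sketch assuming a functional form under `burn::tensor::activation` (signature assumed):

```rust
use burn::backend::NdArray;
use burn::tensor::{activation, Tensor};

fn main() {
    let device = Default::default();
    let x = Tensor::<NdArray, 1>::from_floats([-3.0, 0.0, 3.0], &device);
    // alpha = 0.2, beta = 0.5 (the ONNX defaults).
    let y = activation::hard_sigmoid(x, 0.2, 0.5);
    // Expected: [0.0, 0.5, 1.0] once the ramp is clamped to [0, 1].
    println!("{y}");
}
```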
Dilshod Tadjibaev cd848b1c94
Add is_nan and contains_nan tensor ops (#2088)
* Add is_nan and contains_nan tensor ops

* Enable nan test for burn-candle

* Disabling tests due to #2089
2024-08-06 12:16:12 -05:00
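
A minimal sketch of the two ops named in the title: an element-wise mask versus a single reduced flag (return types assumed):

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let t = Tensor::<NdArray, 1>::from_floats([1.0, f32::NAN, 3.0], &device);
    // `is_nan` yields an element-wise boolean mask...
    let mask = t.clone().is_nan();
    // ...while `contains_nan` reduces it to one flag (assumed here to be
    // a single-element bool tensor).
    assert!(t.contains_nan().into_scalar());
    println!("{mask}");
}
```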
Guillaume Lagrange 27d42cdaad
Fix aggregation results slice (#2110)
* Fix aggregation results slice

* View aggregation results as 1d tensor instead
2024-08-06 12:02:11 -05:00
Dilshod Tadjibaev 80cc6d4eb5
Fix bug: Filling tensor containing f32::NEG_INFINITY will result in NaN for burn-ndarray (#2095)
* Fix #2094 bug

* Fix typo

* Fix mask broadcasting
2024-08-05 08:32:31 -04:00
mepatrick73 f7639bd35a
Repeat operation (#2090)
* renaming repeat to repeat_dim

* implementing repeat function

* renaming repeat files to repeat_dim

* renaming part 2

* renaming part 3

* renaming part 4

* renaming part 5

* adding test file

* adding unit test

* adding rust book documentation

* adding function args doc

* fixing tests

* changing repeat api to match pytorch equivalent

* fixing clippy error
2024-08-02 20:33:47 -04:00
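
A sketch of the resulting API split, with signatures assumed from the commit messages: the old single-axis `repeat` becomes `repeat_dim`, and the new `repeat` takes per-dimension counts like PyTorch's `Tensor.repeat`:

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let t = Tensor::<NdArray, 2>::from_floats([[1.0, 2.0]], &device);
    // Repeat along a single dimension.
    let a = t.clone().repeat_dim(0, 3);
    // Repeat with one count per dimension, PyTorch-style.
    let b = t.repeat(&[2, 2]);
    assert_eq!(a.dims(), [3, 2]);
    assert_eq!(b.dims(), [2, 4]);
}
```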
Dilshod Tadjibaev 297173124f
Add 1d and 2d modules for interpolate with scaling (also fix ONNX Resize op) (#2081)
* Add interpolate module

* Update module.md

* Add interpolate 1d and 2d modules

* Consolidated InterpolateMode for 1d and 2d

* Remove CoordinateTransformationMode

* Add 1d tests for interpolate

* Refactor and fixes of ONNX Resize OP

* Fix clippy

* Fix docs

* Fix no_std
2024-07-31 12:08:26 -05:00
Nathaniel Simard 096ec13c48
Chore/update/cubecl (#2067) 2024-07-28 12:15:02 -04:00
Guillaume Lagrange 64a2f12827
Extend [min, max] range to ensure zero-point (#2055) 2024-07-24 09:55:11 -04:00
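
The reasoning behind this change: affine quantization maps q = round(x / scale) + zero_point, and a real value of exactly 0.0 must quantize without error (e.g. so zero padding survives a round trip), which requires the quantization range to contain 0. A plain-Rust sketch of the parameter computation, independent of burn's internals:

```rust
fn qparams_i8(min: f32, max: f32) -> (f32, i8) {
    // Extend the observed range so it always contains zero.
    let (min, max) = (min.min(0.0), max.max(0.0));
    // Map [min, max] onto the i8 range [-128, 127].
    let scale = (max - min) / 255.0;
    // With 0 inside [min, max], this rounding introduces no error at 0.
    let zero_point = (-128.0 - min / scale).round() as i8;
    (scale, zero_point)
}

fn main() {
    // An all-positive range like [0.5, 2.5] is extended to [0.0, 2.5],
    // so x = 0.0 lands exactly on the integer zero_point (-128 here).
    let (scale, zp) = qparams_i8(0.5, 2.5);
    println!("scale = {scale}, zero_point = {zp}");
}
```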
johnhuichen 4a3fc9d4a0
Implement ONNX Pad Operator (#2007)
* Implement ONNX pad

* ONNX pad arguments fix

pad now requires 2 or more arguments;
if the third argument is not given, it defaults to 0

* fixing bug in input len fix

* change panic comment

Change the panic message about needing two inputs: the ONNX spec requires two mandatory inputs but allows up to two more optional arguments.

---------

Co-authored-by: JC <you@example.com>
Co-authored-by: mepatrick73 <pameu17@ulaval.ca>
2024-07-23 13:50:20 -04:00
Nathaniel Simard 19cd67a9e2
Migration/cubecl (#2041) 2024-07-22 11:08:40 -04:00
Guillaume Lagrange 0d5025edbb
Refactor tensor quantization for q_* ops (#2025)
* Move QuantizationScheme to burn-tensor

* Refactor QuantizedTensorPrimitive to include the quantization strategy

* Fix QFloat tensor data display

* Refactor quantization methods to use scheme and qparams (on backend device)

* Fix clippy

* Fix fmt

* Add qtensor primitive tests
2024-07-19 10:39:50 -04:00
Dilshod Tadjibaev ed8a91d48a
Update slice documentation (#2024) 2024-07-16 11:59:02 -05:00
Guillaume Lagrange 3afff434bd
Module weight quantization (#2000)
* Add q_into_data and q_reshape

* Fix tch quantize f16 and q_into_data

* Convert to actual dtype/kind in dequantize

* Add module quantization and q_from_data

* Fix clippy

* Add documentation

* Handle deserialize data conversion

* Fix typo

* Add calibration tests

* Fix clippy precision

* Add QTensorOps require_grad methods to avoid dequantizing

* Add Dequantize mapper docs

* Remove dead code
2024-07-15 08:20:37 -04:00
Guillaume Lagrange c30ffcf6ac
Enable optimized handling of bytes (#2003)
* Enable optimized handling of bytes

* Implement byte buffer de/serialization

* Use serde_bytes w/ alloc (no_std compatible)
2024-07-11 07:48:43 -04:00
Louis Fortier-Dubois 69be99b802
Cube: Matmul tiling (#1994) 2024-07-09 12:43:13 -04:00
Dilshod Tadjibaev e8b915a2da
Enhance slice operation to support more range variation (#1989)
* Enhance slice operation to support more range variation

* Fix doc clippy

* Fixed doc test

* Fix flipped attribute names

* Fix clippy
2024-07-08 13:34:25 -05:00
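
A before/after sketch of the broadened `slice` API (the exact set of accepted range types is an assumption):

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let t = Tensor::<NdArray, 2>::zeros([4, 6], &device);
    // Plain exclusive ranges were already supported:
    let a = t.clone().slice([0..2, 0..3]);
    // After this change each axis can also take other range forms,
    // e.g. inclusive ranges:
    let b = t.slice([0..=1, 2..=5]);
    assert_eq!(a.dims(), [2, 3]);
    assert_eq!(b.dims(), [2, 4]);
}
```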
Guillaume Lagrange c0211e2f94
Add static tensor quantization (#1963)
* Add QuantizationBackend, QTensorOps and QTensor

* Refactor QTensorOps as part of Backend trait

* Add tensor dequantize, QFloat dtype and default affine/symmetric quant

* Add ndarray default quantization implementation

* Fix clippy

* Add rayon parallel iter

* Add quantization operations to book

* Add q_shape and q_device ops to avoid converting the tensor just to get attributes

* Implement autodiff grad ops

* Mark autodiff todo for QAT

* Remove note

* Add q_inner and q_from_inner
2024-07-08 10:16:58 -04:00
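
A plain-Rust sketch of the two default strategies named above: affine keeps an integer zero-point, while symmetric centers the range on 0 and drops it. This mirrors the math only, not burn's internal types:

```rust
// Affine: q = round(x / scale) + zero_point.
fn quantize_affine(x: f32, scale: f32, zero_point: i8) -> i8 {
    ((x / scale).round() as i32 + zero_point as i32).clamp(-128, 127) as i8
}

fn dequantize_affine(q: i8, scale: f32, zero_point: i8) -> f32 {
    (q as i32 - zero_point as i32) as f32 * scale
}

// Symmetric: zero_point is implicitly 0; the i8 range is used
// symmetrically as [-127, 127].
fn quantize_symmetric(x: f32, scale: f32) -> i8 {
    (x / scale).round().clamp(-127.0, 127.0) as i8
}

fn main() {
    let (scale, zp) = (0.1, -20);
    let q = quantize_affine(1.234, scale, zp);
    println!("{q} -> {}", dequantize_affine(q, scale, zp)); // -8 -> 1.2
    println!("{}", quantize_symmetric(-0.55, 0.01)); // -55
}
```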
nathaniel 882a27c52c Revert "Revert "Implement 3D and transposed 3D convolutions. (#1945)""
This reverts commit b8b47ea6e6.
2024-07-05 18:57:01 -04:00
nathaniel b8b47ea6e6 Revert "Implement 3D and transposed 3D convolutions. (#1945)"
This reverts commit d696d74e3d.
2024-07-05 09:40:32 -04:00
Guillaume Charifi d696d74e3d
Implement 3D and transposed 3D convolutions. (#1945)
* Implement 3D and transposed 3D convolutions.

* Merge changes from onnx-ir #1921 pr

---------

Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-07-02 17:54:35 -05:00
Dilshod Tadjibaev 2bb76283ff
Improve pickle (CandleTensor) conversions to NestedValue (#1944)
* Manually serialize tensor - fixes #1773

* Rename `value` to `bytes`
2024-07-02 08:34:19 -04:00
Arthur Brussee 849c8f453b
Consistent sync/async handling, allow more functions to be async for wasm. (#1936) 2024-07-02 08:25:28 -04:00
Logan B. Nielsen 3a9367de73
remove manual option matching (#1948) 2024-07-01 10:44:10 -04:00
Guillaume Lagrange cdd1fa1672
Refactor tensor data (#1916)
* Move distribution to module

* Add new TensorData with serialization support

* Implement display and from for TensorData

* Add missing Cargo.lock

* Add missing bytemuck feature

* Add zeros, ones, full and random TensorData methods

* Refactor Data -> TensorData usage

* Fix tests

Since TensorData is no longer generic over the element type, the compiler cannot infer the element type. We must explicitly cast the expected results to the expected backend type.

* Remove commented line

* Fix import

* Add record-backward-compat

* Remove dim const generic from TensorData

* Support NestedValue de/serialization with TensorData

* Fix burn-jit tests

* Remove eprintln

* Refactor onnx import to use TensorData

* Fix tch from_data

* Fix nested value serialization for u8

* Fix missing import

* Fix reduce min onnx test

* Fix deprecated attribute

* Remove shape getter

* Remove strict assert in tests

* Add tensor data as_bytes

* Add tensor check for rank mismatch

* Fix typo (dimensions plural)

* Fix error message

* Update book examples with from_data and fix Display impl for TensorData

* Add deprecation note
2024-06-26 20:22:19 -04:00
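
A sketch of the consequence called out in the test-fix note above: `TensorData` is untyped, so comparisons convert the expected values to the backend element type first (treat the exact calls as assumptions):

```rust
use burn::backend::NdArray;
use burn::tensor::{Tensor, TensorData};

fn main() {
    let device = Default::default();
    let data = TensorData::from([[1.0f64, 2.0], [3.0, 4.0]]);
    let t = Tensor::<NdArray, 2>::from_data(data.clone(), &device);
    // Convert the expected values to the backend dtype before comparing,
    // since no element type can be inferred from TensorData itself.
    let expected = data.convert::<f32>();
    t.into_data().assert_eq(&expected, true);
}
```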
Guillaume Lagrange 8071b637b8
Fix conv2d_weight_grad_groups (#1891) 2024-06-17 09:24:33 -04:00
Guillaume Lagrange 525244062f
Implement `Element` for `bool` (#1878)
* Element already implements One

* Add element module

* Add our own traits for Zero, One and ToPrimitive to support bool Element

* Fix typo

* Add basic tests for ToPrimitive with expected values

* The most important change of all

* Remove One + Zero identities

* Move zero/one outside mapv + refactor ToPrimitive -> ToElement trait

* Add num-traits to NOTICES.md
2024-06-14 09:02:38 -04:00
George b71c300638
Feat: Add `movedim` tensor operator (#1876)
* (burn-tensor): add movedim function to tensor API

---------

Co-authored-by: Georgy Andreev <g.andreev@insilicomedicine.com>
2024-06-14 09:01:38 -04:00
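
A usage sketch of `movedim`, which moves a dimension to a new position while preserving the order of the others (PyTorch-style; signature assumed from the commit title):

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let t = Tensor::<NdArray, 3>::zeros([2, 3, 4], &device);
    // Move dim 0 to position 2; dims 1 and 2 shift left.
    let moved = t.movedim(0, 2);
    assert_eq!(moved.dims(), [3, 4, 2]);
}
```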
Jonathan Richard 5de1517232
Add documentation to burn core nn (#1746)
* Updated documentation for unfold4d

Added links between the struct and the config. Added a link to the related burn_tensor function in the documentation for the forward function.

* Changing nn relu module documentation to functional api

Moving the relu formula from the module API docs to the functional API docs,
citing a paper relevant to relu,
and mentioning the functional API in the module API docs

* Linking gelu module API documentation to functional API documentation

* Linear module: adding documentation

Adding documentation to the Linear module,
mentioning that the LinearConfig struct
should be used when creating a Linear layer

Also adding links to the documentation
that point people toward the right path

* Updated documentation for dropout

Added links between the struct and the config. Added a link to the struct in the forward function for more info.

* embedding + swiglu

* RotaryEncoding: adding documentation

Adding documentation stating that RotaryEncoding should be created using a RotaryEncodingConfig

* prelu: adding documentation

Adding documentation to the prelu module:
- Linking forward function documentation to the functional API
- Citing the first paper to mention prelu
- Adding documentation saying that the prelu layer should be created using PReluConfig

* pos_encoding: adding documentation

* Updated documentation for mha

Added links for more info. Added shape info at some places.

* docs: Add documentation for Gru module

Provide documentation for the Gru module, including its configuration and usage. Include a link to the paper that introduced the Gated Recurrent Unit (GRU) and specify that the module should be created using GruConfig. Also, mention that the forward function returns a state tensor with specific dimensions.

* burn-core-nn-transformers: adding documentation

Adding documentation:
- Says to use config to create the layers
- Add mathematical formula to the pwff forward pass
- Add citation in the pwff to the "Attention is all you need" paper

* Updated documentation: ConvTranspose1d and ConvTranspose2d

* docs: Add documentation for Lstm and BiLstm modules

Provide documentation for the Lstm and BiLstm modules, including their configurations and usage. Include links to the papers that introduced Long Short-Term Memory (LSTM) and Bidirectional LSTM. Specify that the modules should be created using LstmConfig and BiLstmConfig respectively.

* docs: Update documentation for ConvTranspose1d and ConvTranspose2d modules

* loss: Adding documentation to the loss layers

Adding documentation stating that the config should be used to create the layer

* chore: Refactor Conv1d module imports and update documentation

* docs: Add documentation for AdaptiveAvgPool1d and AdaptiveAvgPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Refactor Conv1d module imports and update documentation

* chore: Refactor Conv2d module imports and update documentation

* Add documentation for AvgPool1d and AvgPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Add documentation for MaxPool1d and MaxPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Add documentation for leaky_relu and removed Config generic

Added references to the burn_tensor associated functions. Added links between the struct and the config. Removed the backend generic from the config since it's not needed (might be a breaking change).

* refactor: Update BatchNormConfig initialization and add documentation.

* Added link to config in embedding struct documentation

* refactor: Update GroupNormConfig initialization and add documentation

* refactor: Update InstanceNormConfig initialization and add documentation

* feat: Update LayerNormConfig initialization and add documentation

* refactor: Update RmsNormConfig initialization and add documentation

* fixed: removed #derive accidentally

* Added missing backticks in pools' shapes

* Format nn doc

* Make config fields public in nn modules

* Update import statements in nn modules

Changed burn_tensor imports to crate::tensor

* Update import statements in nn modules' tests

Changed burn_tensor imports to crate::tensor

* breaking change refactor: Update GroupNormConfig and InstanceNormConfig initialization

* Make SwiGlu fields public

* grammar

* slashes

* input tensors grouping

* copy-pasta mistake

* a not an >:I

* Capitalization

* better desc

* math 'n ticks

* group_norm functional implementation

* removed the ... struct

* decoder typo

* fmt

* referring to private fn in docs

---------

Co-authored-by: Thierry Cantin-Demers <piertcd@gmail.com>
Co-authored-by: mepatrick73 <pameu17@ulaval.ca>
2024-06-13 12:50:21 -04:00
Arthur Brussee c873d87ac8
Add option to flush queue instead of waiting for completion. (#1864)
* Make sync_type an option on sync instead of adding submit
2024-06-13 09:56:08 -04:00
Louis Fortier-Dubois de5b681b18
Cube: Vectorization + simple matmul implementation (#1866) 2024-06-07 14:05:51 -04:00
Jonas Kantic fba1e27e0c
Remainder operator (#1726)
* Adds remainder ops implementation for Tensor.

* Adds test for % operator.
2024-06-01 16:47:02 -05:00
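
A small sketch of the new op through its `%` sugar; the divisor-sign (floor) semantics shown here match PyTorch's `remainder` and are an assumption about this implementation:

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let t = Tensor::<NdArray, 1>::from_floats([-3.0, -1.0, 2.0, 5.0], &device);
    // Scalar remainder via the `%` operator.
    let r = t % 2.0;
    // Expected with floor semantics: [1.0, 1.0, 0.0, 1.0]
    println!("{r}");
}
```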
McArthur a2ad424fc8
Indices Operator (#1735) 2024-05-29 09:05:31 -04:00
Guillaume Lagrange e4836241e1
Fix `DataSerialize` conversion for elements of the same type (#1832) 2024-05-28 18:12:44 -04:00
Mathias Insley 81ecd14f83
Feat/squeeze dims (#1779) 2024-05-22 07:53:51 -04:00
Mathias Insley 9c5b07c833
Squeeze Onnx Import (#1753) 2024-05-17 12:00:34 -04:00
Ahmed Yarub Hani Al Nuaimi 10737527d8
#1747 Upgrade Rust dependencies (#1748)
* #1747
Upgrade Rust dependencies

* Revert upgrade for tch

The update of tch on windows gives an error:

INTEL MKL ERROR: The specified module could not be found. mkl_vml_avx2.1.dll.
Intel MKL FATAL ERROR: cannot load mkl_vml_avx2.1.dll or mkl_vml_def.1.dll.

* Keep only .cargo/config.toml file which works with rust > 1.75

---------

Co-authored-by: Sylvain Benner <sylvain@benner.online>
2024-05-10 16:25:19 -04:00
Thierry Cantin-Demers b09d8431df
Fix Cargo.toml repository links (#1749)
* Fix wgpu github link

* Fix burn-train repo link

* Fix burn-tensor github repo

* Fix burn-tensor repo link

* Fix remaining repo links in crates Cargo.toml

---------

Co-authored-by: Jonathan Richard <47578360+jwric@users.noreply.github.com>
2024-05-09 15:40:05 -04:00
Nathaniel Simard 5d959e2884
[Fusion] Support multi-precision fusion (#1718) 2024-05-02 18:22:56 -04:00
Louis Fortier-Dubois 2e4c82fa64
Fix repeat for dims > 1 (#1713) 2024-05-01 09:11:38 -04:00