Commit Graph

37 Commits

Author SHA1 Message Date
syl20bnr 8e78106680 Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
Genna Wingert a01004dd4a
Add Hard sigmoid activation function (#2112)
* Add Hard Sigmoid activation function

* Add ONNX import conversion for HardSigmoid

* Update supported operators list

* Update book

* Make test comparison approximate to eliminate precision issues

* Add burn-candle test

* Fix name in E2E test generator
2024-08-07 13:01:42 -05:00
Dilshod Tadjibaev cd848b1c94
Add is_nan and contains_nan tensor ops (#2088)
* Add is_nan and contains_nan tensor ops

* Enable nan test for burn-candle

* Disabling tests due to #2089
2024-08-06 12:16:12 -05:00
mepatrick73 f7639bd35a
Repeat operation (#2090)
* renaming repeat to repeat_dim

* implementing repeat function

* renaming repeat files to repeat_dim

* renaming part 2

* renaming part 3

* renaming part 4

* renaming part 5

* adding test file

* adding unit test

* adding rust book documentation

* adding function args doc

* fixing tests

* changing repeat api to match pytorch equivalent

* fixing clippy error
2024-08-02 20:33:47 -04:00
Guillaume Lagrange 0d5025edbb
Refactor tensor quantization for q_* ops (#2025)
* Move QuantizationScheme to burn-tensor

* Refactor QuantizedTensorPrimitive to include the quantization strategy

* Fix QFloat tensor data display

* Refactor quantization methods to use scheme and qparams (on backend device)

* Fix clippy

* Fix fmt

* Add qtensor primitive tests
2024-07-19 10:39:50 -04:00
Guillaume Lagrange 3afff434bd
Module weight quantization (#2000)
* Add q_into_data and q_reshape

* Fix tch quantize f16 and q_into_data

* Convert to actual dtype/kind in dequantize

* Add module quantization and q_from_data

* Fix clippy

* Add documentation

* Handle deserialize data conversion

* Fix typo

* Add calibration tests

* Fix clippy precision

* Add QTensorOps require_grad methods to avoid dequantizing

* Add Dequantize mapper docs

* Remove dead code
2024-07-15 08:20:37 -04:00
Guillaume Lagrange c0211e2f94
Add static tensor quantization (#1963)
* Add QuantizationBackend, QTensorOps and QTensor

* Refactor QTensorOps as part of Backend trait

* Add tensor dequantize, QFloat dtype and default affine/symmetric quant

* Add ndarray default quantization implementation

* Fix clippy

* Add rayon parallel iter

* Add quantization operations to book

* Add q_shape and q_device ops to avoid converting the tensor just to get attributes

* Implement autodiff grad ops

* Mark autodiff todo for QAT

* Remove note

* Add q_inner and q_from_inner
2024-07-08 10:16:58 -04:00
nathaniel 882a27c52c Revert "Revert "Implement 3D and transposed 3D convolutions. (#1945)""
This reverts commit b8b47ea6e6.
2024-07-05 18:57:01 -04:00
nathaniel b8b47ea6e6 Revert "Implement 3D and transposed 3D convolutions. (#1945)"
This reverts commit d696d74e3d.
2024-07-05 09:40:32 -04:00
Guillaume Charifi d696d74e3d
Implement 3D and transposed 3D convolutions. (#1945)
* Implement 3D and transposed 3D convolutions.

* Merge changes from onnx-ir #1921 pr

---------

Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-07-02 17:54:35 -05:00
Arthur Brussee 849c8f453b
Consistent sync/async handling, allow more functions to be async for wasm. (#1936) 2024-07-02 08:25:28 -04:00
Guillaume Lagrange cdd1fa1672
Refactor tensor data (#1916)
* Move distribution to module

* Add new TensorData with serialization support

* Implement display and from for TensorData

* Add missing Cargo.lock

* Add missing bytemuck feature

* Add zeros, ones, full and random TensorData methods

* Refactor Data -> TensorData usage

* Fix tests

Since TensorData is not generic over the element type anymore no type inference can be done by the compiler. We must explicitly cast the expected results to the expected backend type.

* Remove commented line

* Fix import

* Add record-backward-compat

* Remove dim const generic from TensorData

* Support NestedValue de/serialization with TensorData

* Fix burn-jit tests

* Remove eprinln

* Refactor onnx import to use TensorData

* Fix tch from_data

* Fix nested value serialization for u8

* Fix missing import

* Fix reduce min onnx test

* Fix deprecated attribute

* Remove shape getter

* Remove strict assert in tests

* Add tensor data as_bytes

* Add tensor check for rank mismatch

* Fix typo (dimensions plural)

* Fix error message

* Update book examples with from_data and fix Display impl for TensorData

* Add deprecation note
2024-06-26 20:22:19 -04:00
Arthur Brussee c873d87ac8
Add option to flush queue instead of waiting for completion. (#1864)
* Make sync_type an option on sync instead of adding submit
2024-06-13 09:56:08 -04:00
Guillaume Lagrange e4836241e1
Fix `DataSerialize` conversion for elements of the same type (#1832) 2024-05-28 18:12:44 -04:00
Ahmed Yarub Hani Al Nuaimi 10737527d8
#1747 Upgrade Rust dependencies (#1748)
* #1747
Upgrade Rust dependencies

* Revert upgrade for tch

The update of tch on windows gives an error:

INTEL MKL ERROR: The specified module could not be found. mkl_vml_avx2.1.dll.
Intel MKL FATAL ERROR: cannot load mkl_vml_avx2.1.dll or mkl_vml_def.1.dll.

* Keep only .cargo/config.toml file which works with rust > 1.75

---------

Co-authored-by: Sylvain Benner <sylvain@benner.online>
2024-05-10 16:25:19 -04:00
Thierry Cantin-Demers b09d8431df
Fix Cargo.toml repository links (#1749)
* Fix wgpu github link

* Fix burn-train repo link

* Fix burn-tensor github repo

* Fix burn-tensor repo link

* Fix remaining repo links in crates Cargo.toml

---------

Co-authored-by: Jonathan Richard <47578360+jwric@users.noreply.github.com>
2024-05-09 15:40:05 -04:00
Sylvain Benner c579686a8a
Move HandleContainer and Tensor Ops descriptions from burn-fusion to burn-tensor (#1654)
* Move HandlerContainer and Tensor Ops description to burn-tensor

Move HandleContainer and Tensor operations descriptions to burn-tensor crate.
Removed the FusionDevice and replaced it with a DeviceOps trait bound to Backend::Device.

For now added modules to burn-tensor are excluded from no-std as they rely on Arc.

* [burn-tensor] Flatten module hierarchy for tensor representation

+ Add new repr feature to cargo file.

* Remove prefix on dosctring

* [burn-fusion] Require default features of burn-tensor
2024-04-23 11:27:54 -04:00
Dilshod Tadjibaev 2a721a9d0c
Enable native sign operation for Candle backend (#1647)
* Enable native sign operation for Candle backend

* Use fixed revision
2024-04-17 09:07:56 -04:00
Mathias Insley 7377bbe31c
Feat/remainder (#1597)
* Add remainder_scalar op to numeric trait and associated int/float functions

* Update burn-tch crate

* Update ndarray crate

* Update jit crate

* Update candle crate

* Update fusion crate

* Update autodiff crate

* Forgot float.rs for fusion

* Add burn-tensor tests

* Redirect to the pre-existing modulus op

* Fix sign

* Remove mut from burn-tch

* Use sign trick to make wgpu backend work

* Add more unit tests in to cover bases

* Naming fix for burn-fusion

* Update tests w/PyTorch link

* Use different WGSL instructions for remainder

* Redirect to remainder Operator instead of modulo

* Revert Modulo in instruction.rs
2024-04-16 08:35:20 -04:00
Sylvain Benner e303e31c8b
Bump next version of Burn to 0.14.0 (#1618) 2024-04-12 17:14:45 -04:00
Guillaume Lagrange 264c167c11
Update licenses symlinks (#1613) 2024-04-12 14:43:58 -04:00
Nathaniel Simard ff844b1667
Fix candle backend sync (#1579)
* Fix candle backend sync

* tch mps sync

* clippy

---------

Co-authored-by: louisfd <louisfd94@gmail.com>
2024-04-12 12:15:50 -04:00
Nathaniel Simard 40a26bd2ea
Feat/backend bridge (#1529) 2024-03-26 19:24:45 -04:00
Dilshod Tadjibaev 6feda90a8c
Tensor expand operator (#1508)
* Improve CI cache - remove burn-tch artifacts

* PyTorch config deserializer from .pt file

* Update pytorch-model.md

* WIP

* Rename broadcast_to to expand

* Rename broadcast_to expand file

* Implemented fusion backend and fix bugs

* Remove old files

* Remove unused state

* Rename to the correct op name

* Add missing comment

* Fix expand check function doc

* Rename the leftover names

* Rename leftover names
2024-03-22 16:33:53 -05:00
carrotflakes 8911093b88
Add `flip` tensor operator (#1468) 2024-03-18 20:33:39 -05:00
Dilshod Tadjibaev 7a98b2f663
Add prod and prod_dim tensor ops (#1460) 2024-03-12 14:00:02 -05:00
Dilshod Tadjibaev 3f7e6bd5bc
Add `sign` tensor operator (#1446) 2024-03-11 10:39:30 -05:00
Aasheesh Singh 0c92c8c8eb
Autodiff/training support for Nearest Interpolation (#1414)
Add training support for nearest interpolation

---------

Co-authored-by: yurzhang <yurzhang.oi@gmail.com>
Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-03-06 00:12:05 -05:00
Dilshod Tadjibaev c7834e4658
Tensor `permute` operator (#1410) 2024-03-05 12:29:13 -05:00
Dilshod Tadjibaev 4ed90a988e
Add `bool()` op for numerical tensor (#1402)
Fixes #1395
2024-03-04 12:39:17 -06:00
Guillaume Lagrange 16d7666611
Add `argwhere` and `nonzero` boolean tensor ops (#1394)
* Add argwhere and nonzero bool tensor ops

* Fix wasm build

* Add missing vec

* Fix wasm cfg placement

* Fix comment
2024-03-04 08:33:59 -05:00
yurzhang 7d44f0b2d7
Interpolate tensor operation (Inference Only) (#1246)
* squash

feat: bilinear interpolation for tch, ndarray and wgpu backend

fix: reduce test case size to avoid exceeding floating-point precision limits

feat: support nearest-neighbor interpolation for ndarray backend

feat: support nearest-neighbor interpolation for wgpu backend

feat: support fusion backend

fix: no-std support

build: upgrade dependencies

* feat: bicubic interpolation for ndarray backend

* fix: test case precision

* feat: bicubic interpolation for wgpu backend

* Update Cargo.lock

---------

Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
Co-authored-by: Aasheesh Singh <20820983+ashdtu@users.noreply.github.com>
2024-03-02 12:01:35 -06:00
Dilshod Tadjibaev d43a0b3f90
Add is_close and all_close tensor operators (#1389)
* Add is_close and all_close tensor operators

* Fix broken build issues

* Fix the table

* Add tests to candle
2024-03-01 15:37:14 -06:00
Guillaume Lagrange 4efc683df4
Upgrade to candle 0.4.1 (#1382)
* Fix python main entrypoint in book example

* Remove candle windows safeguards (#1178)

* Bump candle-core from 0.3.3 to 0.4.1

* Remove windows current known issue
2024-02-29 11:29:11 -06:00
Mathias Insley bb5e6faff2
Feat/autotune int ops (#1136)
* Add int_random to int tensor ops

* Int random for tch backend

* Int random for burn-fusion

* int random for autodiff

* Int random for candle backend

* Int random for ndarray backend

* Int random for wgpu backend

* Merge imports

* Typo

* Shader file for int uniform distribution

* Create AutotuneOperationSet and public int_sum_dim_autotune

* Adjust bounds to 0..10

* Create uniform_int_kernel, unit tests, use new kernel

* Reduction kernels for regular and shared memory sum_dim int operations

* Macro that accomadates wgpu IntElement

* Add autotuning to int_mean_dim

* Use correct macro for Int autotuning

* Add int_mean_dim_shared_memory

* Add int_mean_dim and unit test

* Create autotunables for mean_dim

* Run fmt

* Remove comment

* Finish resolving merge conflict, fix doc

* Make the element trait bound a parameter to reduce_tune_ops macro

* Update book

* Fix requested change

* Change range to [0, 255] and update test accordingly

* Forgot to include candle in last commit

* Fix comment

* Use correct int autotune for mean dim

* Fix typo- not sure how this passed earlier

* Resolve syntax issues from merge

* Fix cast_float

* Saving here

* Continue fixing merge conflicts, all tests pass locally

* Run fmt

* Change cast_float to cast_u32_to_float

* Make uniform_int_inner_loop safer

* Be even more explicit about u32 casts

* Skip an intermediate step and cast directly to u32

* Replace JitElement + Element with IntElement

* Run fmt

* This should fix the CI

* This time for sure
2024-02-26 14:53:21 -05:00
Arjun31415 8e23057c6b
Feature Addition: PRelu Module (#1328) 2024-02-24 10:24:22 -05:00
Sylvain Benner 4427768570
[refactor] Move burn crates to their own crates directory (#1336) 2024-02-20 13:57:55 -05:00