Commit Graph

340 Commits

Author SHA1 Message Date
Dirley Jordan 9e6777d6a5
Add ReduceProd ONNX Import (#1955)
* Preliminary ReduceProd Support

* Add comma to keep formatter happy

* Give test results a 0.001 tolerance to account for floating-point multiplication

* Reformat assertions

* Correctly mark panic conditions in op_configuration
2024-07-02 09:05:28 -04:00
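For context on the 0.001 tolerance above: exact float equality is brittle after repeated multiplication, so such tests compare within an epsilon instead. A minimal sketch of that kind of check (illustrative only, not the actual test code from #1955):

```rust
// Products of f32 values accumulate rounding error, so compare within a
// tolerance instead of demanding exact equality.
fn assert_approx_eq(actual: f32, expected: f32, tol: f32) {
    assert!(
        (actual - expected).abs() <= tol,
        "expected {expected}, got {actual} (tolerance {tol})"
    );
}

fn main() {
    let product: f32 = [1.1f32, 2.2, 3.3].iter().product();
    assert_approx_eq(product, 7.986, 0.001); // exact `==` could fail here
}
```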
Dilshod Tadjibaev 2bb76283ff
Improve pickle (CandleTensor) conversions to NestedValue (#1944)
* Manually serialize tensor - fixes #1773

* Rename `value` to `bytes`
2024-07-02 08:34:19 -04:00
Nathaniel Simard 82a883a57d
Feat/cube/fma (#1947) 2024-07-02 08:32:39 -04:00
Nathaniel Simard cb6b5e7183
Feat/cube/cooperative matrix-multiply and accumulate. (#1943) 2024-07-02 08:31:00 -04:00
Nathaniel Simard ad81a997af
Perf: cube reuse shape and strides (#1939) 2024-07-02 08:28:32 -04:00
Arthur Brussee 849c8f453b
Consistent sync/async handling, allow more functions to be async for wasm. (#1936) 2024-07-02 08:25:28 -04:00
Logan B. Nielsen 3a9367de73
remove manual option matching (#1948) 2024-07-01 10:44:10 -04:00
Guillaume Lagrange e753b0c4e7
Fix output tensor dtype (#1938) 2024-07-01 10:27:31 -04:00
Roy Varon a7efc102b9
Replaced `str` with `Path` (#1919)
* replaced str with Path

* minor change (Path to AsRef<Path>)

* fixed clippy lint
2024-06-29 18:17:59 -05:00
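The `Path` change above follows the standard-library convention (cf. `std::fs::File::open`): accepting `impl AsRef<Path>` lets callers pass `&str`, `String`, `&Path`, or `PathBuf` alike. A minimal sketch of the pattern (illustrative, not the actual Burn signatures):

```rust
use std::path::{Path, PathBuf};

// Accepting `AsRef<Path>` instead of `&str` keeps call sites flexible.
fn describe(path: impl AsRef<Path>) -> String {
    format!("loading from {}", path.as_ref().display())
}

fn main() {
    println!("{}", describe("model.mpk"));            // &str still works
    println!("{}", describe(PathBuf::from("model"))); // so does PathBuf
}
```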
Dilshod Tadjibaev 98a58c867d
Print module - implement module display for remaining modules (part2) (#1933) 2024-06-28 08:37:40 -04:00
Nathaniel Simard 1ae1c03b2d
Refactor/cube/mutability (#1934) 2024-06-27 16:03:23 -04:00
Guillaume Lagrange cdd1fa1672
Refactor tensor data (#1916)
* Move distribution to module

* Add new TensorData with serialization support

* Implement display and from for TensorData

* Add missing Cargo.lock

* Add missing bytemuck feature

* Add zeros, ones, full and random TensorData methods

* Refactor Data -> TensorData usage

* Fix tests

Since TensorData is no longer generic over the element type, no type inference can be done by the compiler. We must explicitly cast the expected results to the expected backend element type.

* Remove commented line

* Fix import

* Add record-backward-compat

* Remove dim const generic from TensorData

* Support NestedValue de/serialization with TensorData

* Fix burn-jit tests

* Remove eprintln

* Refactor onnx import to use TensorData

* Fix tch from_data

* Fix nested value serialization for u8

* Fix missing import

* Fix reduce min onnx test

* Fix deprecated attribute

* Remove shape getter

* Remove strict assert in tests

* Add tensor data as_bytes

* Add tensor check for rank mismatch

* Fix typo (dimensions plural)

* Fix error message

* Update book examples with from_data and fix Display impl for TensorData

* Add deprecation note
2024-06-26 20:22:19 -04:00
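For readers skimming the refactor above: the key design point is that `TensorData` drops both the element-type generic and the dim const generic, so the element type becomes a runtime tag over a raw byte buffer (hence `as_bytes`). A simplified illustration of that shape of API, not Burn's actual definition:

```rust
// Illustrative sketch only: element type as a runtime tag over raw bytes.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    F32,
    I64,
}

struct TensorData {
    bytes: Vec<u8>,
    shape: Vec<usize>, // plain Vec: no dim const generic anymore
    dtype: DType,
}

impl TensorData {
    fn from_f32(values: &[f32], shape: Vec<usize>) -> Self {
        let bytes = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        Self { bytes, shape, dtype: DType::F32 }
    }

    fn as_bytes(&self) -> &[u8] {
        &self.bytes
    }
}

fn main() {
    let data = TensorData::from_f32(&[1.0, 2.0], vec![2]);
    assert_eq!(data.as_bytes().len(), 8);
    assert_eq!(data.dtype, DType::F32);
    assert_eq!(data.shape, vec![2]);
}
```

This is also why the commit notes that tests must explicitly cast expected results: with no generic parameter left, the compiler cannot infer the element type.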
mepatrick73 1c7780aaac
Feat/dynamic small pool (#1931) 2024-06-26 15:42:04 -04:00
Nathaniel Simard f9ec2e1006
Handle visibility in cube (#1929) 2024-06-26 12:57:47 -04:00
Nathaniel Simard d772a1cfd5
Fix: launch without generics (#1932) 2024-06-26 12:57:32 -04:00
mepatrick73 4c9097030f
Perf/dynamic mm slice addressing (#1917)
* basic implementation of virtual memory addressing for fast index + merging (there is a bug with slice padding)
2024-06-25 18:16:46 -04:00
Nathaniel Simard 2fbc4628f3
Feat/cube/array assign ops (#1914) 2024-06-25 09:55:55 -04:00
Dilshod Tadjibaev 2c51615471
Print model structure like with PyTorch - Part 1 (#1912) 2024-06-25 09:23:10 -04:00
Nathaniel Simard a5dfb87828
Feat/comptime expr (#1910)
* Support comptime expressions

* Add test

* Cleanup

* Fix
2024-06-20 16:00:22 -04:00
Nathaniel Simard efc13d9a38
Feat/cube/compile error (#1909) 2024-06-19 17:21:32 -04:00
Nathaniel Simard d50bac165e
feat cube support Array (#1907) 2024-06-19 17:03:02 -04:00
Arthur Brussee 14d1bbba64
Do not use default burn-compute features unless enabled. (#1908) 2024-06-19 10:12:11 -04:00
Nathaniel Simard 560d77d154
Doc: Improve module to_device/fork docs (#1901) 2024-06-18 16:45:38 -04:00
Nathaniel Simard e758fd43db
Fix: constant record loading (#1902) 2024-06-18 16:45:21 -04:00
Justin Restivo 263add23a0
Tanh nn wrapper (#1903) 2024-06-18 16:45:04 -04:00
phenylshima f8a7c54272
feat: Make RetroForward public (#1905) 2024-06-18 16:44:32 -04:00
jachym 96468fc3c9
feat: added reduce min onnx import (#1894) 2024-06-18 09:04:24 -04:00
Nathaniel Simard 4f6db974a1
Perf/dynamic mm (#1906) 2024-06-18 08:41:07 -04:00
Guillaume Lagrange 8071b637b8
Fix conv2d_weight_grad_groups (#1891) 2024-06-17 09:24:33 -04:00
github-actions[bot] a04da9a285
Combined PRs (#1900)
* Bump cudarc from 0.11.4 to 0.11.6

Bumps [cudarc](https://github.com/coreylowman/cudarc) from 0.11.4 to 0.11.6.
- [Release notes](https://github.com/coreylowman/cudarc/releases)
- [Commits](https://github.com/coreylowman/cudarc/compare/v0.11.4...v0.11.6)

---
updated-dependencies:
- dependency-name: cudarc
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump derive_more from 0.99.17 to 0.99.18

Bumps [derive_more](https://github.com/JelteF/derive_more) from 0.99.17 to 0.99.18.
- [Release notes](https://github.com/JelteF/derive_more/releases)
- [Changelog](https://github.com/JelteF/derive_more/blob/v0.99.18/CHANGELOG.md)
- [Commits](https://github.com/JelteF/derive_more/compare/v0.99.17...v0.99.18)

---
updated-dependencies:
- dependency-name: derive_more
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump tokio from 1.37.0 to 1.38.0

Bumps [tokio](https://github.com/tokio-rs/tokio) from 1.37.0 to 1.38.0.
- [Release notes](https://github.com/tokio-rs/tokio/releases)
- [Commits](https://github.com/tokio-rs/tokio/compare/tokio-1.37.0...tokio-1.38.0)

---
updated-dependencies:
- dependency-name: tokio
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump regex from 1.10.4 to 1.10.5

Bumps [regex](https://github.com/rust-lang/regex) from 1.10.4 to 1.10.5.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/compare/1.10.4...1.10.5)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-06-17 09:06:01 -04:00
Arthur Brussee ac9f942a46
Remove GraphicsAPI generic for WgpuRuntime (#1888) 2024-06-17 09:04:25 -04:00
Joshua Ferguson eead748e90
add dependency management for python (#1887) 2024-06-17 09:00:38 -04:00
Louis Fortier-Dubois 8bf1cd60dc
Cube: variable reusability + refactor in cube macros (#1885) 2024-06-14 11:20:25 -04:00
Guillaume Lagrange 525244062f
Implement `Element` for `bool` (#1878)
* Element already implements One

* Add element module

* Add our own traits for Zero, One and ToPrimitive to support bool Element

* Fix typo

* Add basic tests for ToPrimitive with expected values

* The most important change of all

* Remove One + Zero identities

* Move zero/one outside mapv + refactor ToPrimitive -> ToElement trait

* Add num-traits to NOTICES.md
2024-06-14 09:02:38 -04:00
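Context for the `ToPrimitive -> ToElement` move above: `num-traits`' `ToPrimitive` is not implemented for `bool`, so supporting `bool` as an `Element` requires a crate-local conversion trait. A hypothetical sketch of the idea (trait shape assumed, not copied from Burn):

```rust
// Hypothetical sketch: a crate-local trait can cover `bool`, unlike
// num-traits' ToPrimitive.
trait ToElement {
    fn to_f32(&self) -> f32;
}

impl ToElement for bool {
    fn to_f32(&self) -> f32 {
        if *self { 1.0 } else { 0.0 }
    }
}

impl ToElement for f32 {
    fn to_f32(&self) -> f32 {
        *self
    }
}

fn main() {
    assert_eq!(true.to_f32(), 1.0);
    assert_eq!(false.to_f32(), 0.0);
}
```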
George b71c300638
Feat: Add `movedim` tensor operator (#1876)
* (burn-tensor): add movedim function to tensor API

---------

Co-authored-by: Georgy Andreev <g.andreev@insilicomedicine.com>
2024-06-14 09:01:38 -04:00
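Assuming `movedim` mirrors the PyTorch operator of the same name (move the dimension at `src` to position `dst`, shifting the others), the shape bookkeeping looks like this; a standalone sketch, not Burn's implementation:

```rust
/// Output shape of a PyTorch-style `movedim(src, dst)`.
fn movedim_shape(shape: &[usize], src: usize, dst: usize) -> Vec<usize> {
    let mut out = shape.to_vec();
    let moved = out.remove(src);
    out.insert(dst, moved);
    out
}

fn main() {
    // Moving dim 0 of a [2, 3, 4] tensor to the last position.
    assert_eq!(movedim_shape(&[2, 3, 4], 0, 2), vec![3, 4, 2]);
}
```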
Arthur Brussee 47a81270e1
Make autodiff compile on wasm (#1889) 2024-06-14 08:12:14 -04:00
Nathaniel Simard 5e58ae1a02
Refactor the tuner to be used standalone (#1884)
* Refactor the tuner to be used standalone

* Add a name for the autotune cache

* Fix tests

* Fix typo
2024-06-13 13:23:58 -04:00
Jonathan Richard 5de1517232
Add documentation to burn core nn (#1746)
* Updated documentation for unfold4d

Added links between the struct and the config. Added a link to the related burn_tensor function in the documentation for the forward function.

* Changing nn relu module documentation to functional api

Moving the formula for relu from the module API to the functional API,
citing a paper relevant to relu,
and mentioning the functional API in the module API

* Linking gelu module API documentation to functional API documentation

* Linear module: adding documentation

Adding documentation to the Linear module,
mentioning that the LinearConfig struct
should be used when creating a Linear layer

Also adding links to the documentation that point people toward
the right path

* Updated documentation for dropout

Added links between the struct and the config. Added a link to the struct in the forward function for more info.

* embedding + swiglu

* RotaryEncoding: adding documentation

Adding documentation stating that RotaryEncoding should be created using a RotaryEncodingConfig

* prelu: adding documentation

Adding documentation to the prelu module:
- Linking forward function documentation to the functional API
- Citing the first paper to mention prelu
- Adding documentation saying that the prelu layer should be created using PReluConfig

* pos_encoding: adding documentation

* Updated documentation for mha

Added links for more info. Added shape info at some places.

* docs: Add documentation for Gru module

Provide documentation for the Gru module, including its configuration and usage. Include a link to the paper that introduced the Gated Recurrent Unit (GRU) and specify that the module should be created using GruConfig. Also, mention that the forward function returns a state tensor with specific dimensions.

* burn-core-nn-transformers: adding documentation

Adding documentation:
- Says to use config to create the layers
- Add mathematical formula to the pwff forward pass
- Add citation in the pwff to the "Attention is all you need" paper

* Updated documentation: ConvTranspose1d and ConvTranspose2d

* docs: Add documentation for Lstm and BiLstm modules

Provide documentation for the Lstm and BiLstm modules, including their configurations and usage. Include links to the papers that introduced Long Short-Term Memory (LSTM) and Bidirectional LSTM. Specify that the modules should be created using LstmConfig and BiLstmConfig respectively.

* docs: Update documentation for ConvTranspose1d and ConvTranspose2d modules

* loss: Adding documentation to the loss layers

Adding documentation stating that the config should be used to create the layer

* chore: Refactor Conv1d module imports and update documentation

* docs: Add documentation for AdaptiveAvgPool1d and AdaptiveAvgPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Refactor Conv1d module imports and update documentation

* chore: Refactor Conv2d module imports and update documentation

* Add documentation for AvgPool1d and AvgPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Add documentation for MaxPool1d and MaxPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Add documentation for leaky_relu and removed Config generic

Added references to the burn_tensor associated functions. Added links between the struct and the config. Removed the backend generic from the config since it's not needed (might be a breaking change).

* refactor: Update BatchNormConfig initialization and add documentation.

* Added link to config in embedding struct documentation

* refactor: Update GroupNormConfig initialization and add documentation

* refactor: Update InstanceNormConfig initialization and add documentation

* feat: Update LayerNormConfig initialization and add documentation

* refactor: Update RmsNormConfig initialization and add documentation

* fixed: removed #derive accidentally

* Added missing backticks in pools' shapes

* Format nn doc

* Make config fields public in nn modules

* Update import statements in nn modules

Changed burn_tensor imports to crate::tensor

* Update import statements in nn modules' tests

Changed burn_tensor imports to crate::tensor

* breaking change refactor: Update GroupNormConfig and InstanceNormConfig initialization

* Make SwiGlu fields public

* grammar

* slashes

* input tensors grouping

* copy-pasta mistake

* a not an >:I

* Capitalization

* better desc

* math 'n ticks

* group_norm functional implementation

* removed the ... struct

* decoder typo

* fmt

* referring to private fn in docs

---------

Co-authored-by: Thierry Cantin-Demers <piertcd@gmail.com>
Co-authored-by: mepatrick73 <pameu17@ulaval.ca>
2024-06-13 12:50:21 -04:00
Louis Fortier-Dubois 4393b336bc
clippy on rust update (#1886) 2024-06-13 12:15:15 -04:00
Arthur Brussee c873d87ac8
Add option to flush queue instead of waiting for completion. (#1864)
* Make sync_type an option on sync instead of adding submit
2024-06-13 09:56:08 -04:00
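The sync option described above distinguishes two behaviors; a hypothetical sketch of the shape of that API (the `SyncType` name follows the commit message, the details are assumed):

```rust
// Hypothetical sketch: submitting queued work vs. blocking on completion.
enum SyncType {
    /// Submit all pending work and block until the device has finished it.
    Wait,
    /// Submit all pending work, but return immediately.
    Flush,
}

fn sync(sync_type: SyncType) {
    submit_pending_work(); // the queue is flushed in both cases
    if let SyncType::Wait = sync_type {
        block_until_complete(); // only `Wait` pays the stall
    }
}

fn submit_pending_work() {}
fn block_until_complete() {}

fn main() {
    sync(SyncType::Flush); // keep the device busy without stalling the CPU
    sync(SyncType::Wait);
}
```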
Mitchell Mosure 71bd5efbfa
feat: resize onnx import (#1863)
* feat: resize onnx import

* fix: resize import proc macro output

* fix: lint

* fix: simplify resize onnx

* fix: onnx-tests passing

* feedback: remove dead code and resolve merge conflicts
2024-06-11 13:22:33 -04:00
jachym 671ec8c679
feat: added slice onnx import (#1856)
* feat: added slice onnx import

* fix: axes, steps handling
2024-06-11 07:50:03 -04:00
github-actions[bot] dd60446946
Combined PRs (#1874)
* Bump cudarc from 0.11.0 to 0.11.4

Bumps [cudarc](https://github.com/coreylowman/cudarc) from 0.11.0 to 0.11.4.
- [Release notes](https://github.com/coreylowman/cudarc/releases)
- [Commits](https://github.com/coreylowman/cudarc/compare/v0.11.0...v0.11.4)

---
updated-dependencies:
- dependency-name: cudarc
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump clap from 4.5.4 to 4.5.6

Bumps [clap](https://github.com/clap-rs/clap) from 4.5.4 to 4.5.6.
- [Release notes](https://github.com/clap-rs/clap/releases)
- [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md)
- [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.4...v4.5.6)

---
updated-dependencies:
- dependency-name: clap
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump tar from 0.4.40 to 0.4.41

Bumps [tar](https://github.com/alexcrichton/tar-rs) from 0.4.40 to 0.4.41.
- [Commits](https://github.com/alexcrichton/tar-rs/compare/0.4.40...0.4.41)

---
updated-dependencies:
- dependency-name: tar
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump strum_macros from 0.26.2 to 0.26.4

Bumps [strum_macros](https://github.com/Peternator7/strum) from 0.26.2 to 0.26.4.
- [Release notes](https://github.com/Peternator7/strum/releases)
- [Changelog](https://github.com/Peternator7/strum/blob/master/CHANGELOG.md)
- [Commits](https://github.com/Peternator7/strum/commits)

---
updated-dependencies:
- dependency-name: strum_macros
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump zip from 2.1.2 to 2.1.3

Bumps [zip](https://github.com/zip-rs/zip2) from 2.1.2 to 2.1.3.
- [Release notes](https://github.com/zip-rs/zip2/releases)
- [Changelog](https://github.com/zip-rs/zip2/blob/master/CHANGELOG.md)
- [Commits](https://github.com/zip-rs/zip2/compare/v2.1.2...v2.1.3)

---
updated-dependencies:
- dependency-name: zip
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-06-10 16:22:08 -04:00
Joshua Ferguson effce28b72
Optimize argument handling and improve ONNX graph building (#1857)
* draft for alternative burn import design

* passes onnx test, fails to build example

* pushing to test example on main

* fixed the issue with the example

* passes the test now

* spring cleaning and minor code changes

* removed pub visibility from most graph_data fields and functions

* comment fixes

* went ahead and removed the constant check for now

* removed unused function arg
2024-06-10 14:06:54 -05:00
Louis Fortier-Dubois de5b681b18
Cube: Vectorization + simple matmul implementation (#1866) 2024-06-07 14:05:51 -04:00
Arthur Brussee 4b174a88bd
Get resources from server (#1861) 2024-06-06 17:33:57 -04:00
Arthur Brussee 75e26d03c3
Speedup client.create for small allocations. (#1858)
* Speedup client.create for small allocations.
2024-06-06 17:09:01 -04:00
Arthur Brussee 675f6b3280
Make Param.id public (#1859)
* Make Param.id public

* Remove extra comment.
2024-06-06 11:03:14 -04:00
Icekey d28183c7e4
LearnerBuilder "with_checkpointing_strategy" should use builder pattern (#1841) 2024-06-05 07:55:44 -04:00
Arthur Brussee e0a1094f89
Add a feature to initialize from an existing wgpu adapter/device/queue (#1788)
* Add a feature to initialize from an existing wgpu adapter/device/queue

This is useful when interacting with other wgpu applications (e.g. displaying a burn tensor as a texture in egui). The existing devices are keyed by the wgpu Device ID. Alternatively, they could be keyed per adapter, which would be more in line with other burn WgpuDevices (one per adapter), but there's no real inherent reason to.

This also involves making Queue into an Arc. Alternatively, this could give up ownership of the queue, but it's helpful to be able to synchronize burn operations and custom wgpu operations.
2024-06-05 07:19:52 -04:00
mepatrick73 36ed65a5cd
Feat/dynamic mm basic implementation + small refactor (#1844) 2024-06-04 17:01:33 -04:00
Louis Fortier-Dubois c42abadfe9
Cube: CubeType (no launch) and Comptime::map (#1853) 2024-06-04 13:43:43 -04:00
jachym a5af19b959
feat: add sum onnx import (#1846) 2024-06-03 15:30:44 -05:00
Louis Fortier-Dubois 5edaeabcee
Feat/cube/struct support (#1842)
* struct support (receive, use and modify fields)

* support struct with generics

* expect instead of unwrap

* fmt

* rename struc

* fmt

* Clippy

* Fix launcher

* Support creating private cube type without generics

* Cleanup

* generics support

* clippy

* minor

* fmt

---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-06-03 12:19:05 -04:00
Mathias Insley 92b0067693
Feat/gather import (#1843)
* Move and redirect GatherElements to new folders/nodes

* Create PyTorch script for gather

* Add onnx file for gather

* Add a gather test to onnx_tests

* Update gather.rs to use select

* Rename codegen test

* Update gather and gather_elements conversion functions

* Validate rank of input node and update output

* Add check for Gather
2024-06-03 08:28:32 -04:00
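The move above separates the ONNX `Gather` op (which selects whole slices along an axis, like `select` / PyTorch's `index_select`) from `GatherElements` (which picks per-element, like PyTorch's `gather`). A small sketch of the slice-selection semantics, standard ONNX behavior rather than Burn's code:

```rust
// ONNX Gather along axis 0: indices pick whole rows.
fn gather_rows(data: &[Vec<i32>], indices: &[usize]) -> Vec<Vec<i32>> {
    indices.iter().map(|&i| data[i].clone()).collect()
}

fn main() {
    let data = vec![vec![1, 2], vec![3, 4], vec![5, 6]];
    // Rows 2 and 0, in that order.
    assert_eq!(gather_rows(&data, &[2, 0]), vec![vec![5, 6], vec![1, 2]]);
}
```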
Jonas Kantic fba1e27e0c
Remainder operator (#1726)
* Adds remainder ops implementation for Tensor.

* Adds test for % operator.
2024-06-01 16:47:02 -05:00
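Supporting `%` on tensors means implementing `core::ops::Rem`. A simplified scalar sketch of the wiring (Burn's actual impl is over `Tensor`; the floored semantics here match PyTorch's `remainder` and are an assumption on my part):

```rust
use core::ops::Rem;

#[derive(Clone, Copy, Debug, PartialEq)]
struct Scalar(f32);

impl Rem for Scalar {
    type Output = Scalar;
    fn rem(self, rhs: Scalar) -> Scalar {
        // Floored (Python/PyTorch-style) remainder: the result takes the
        // sign of the divisor, unlike Rust's truncated `%` on primitives.
        Scalar(self.0 - rhs.0 * (self.0 / rhs.0).floor())
    }
}

fn main() {
    assert_eq!(Scalar(-7.0) % Scalar(3.0), Scalar(2.0)); // -7.0f32 % 3.0 == -1.0
}
```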
jachym.putta 99e1ba4864
feat: expand onnx import (#1813)
* feat: added expand to import
2024-05-31 16:48:02 -05:00
jachym.putta 44f1053219
feat: added range onnx import (#1834)
* feat: added range onnx import

* fix: range input types
2024-05-31 16:40:54 -05:00
Nathaniel Simard 36d4bcd705
[Refactor - Breaking] Refactor cube operations with better names & Support subgroup operations (#1839) 2024-05-31 17:07:21 -04:00
will-maclean 13a6f84bc3
Feature/onnx argmax (#1814)
* pre-test

* implementing argmax for burn-import from onnx

* tidying

* fixing return types and tests

* addressing feedback

* only warn when select_last_index != 0
2024-05-31 14:46:09 -04:00
Louis Fortier-Dubois de0b49e4a3
Cube: Topology constants (#1838)
---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-05-30 12:03:30 -04:00
Louis Fortier-Dubois 61c9fdbbc8
Cube: cleaner use of topology values (#1835)
* constant keyword parsing

* works
2024-05-29 09:08:10 -04:00
McArthur a2ad424fc8
Indices Operator (#1735) 2024-05-29 09:05:31 -04:00
Louis Fortier-Dubois cacc764205
Cube: support for shared memory (#1831) 2024-05-29 08:22:04 -04:00
Guillaume Lagrange e4836241e1
Fix `DataSerialize` conversion for elements of the same type (#1832) 2024-05-28 18:12:44 -04:00
Louis Fortier-Dubois e61b026918
Cube: support method call + prettier tensor metadata (#1829) 2024-05-27 15:18:17 -04:00
Nathaniel Simard fd54a8b470
Add vectorization support into cube (#1830) 2024-05-27 14:21:29 -04:00
Louis Fortier-Dubois dc85daa1c6
Cube: support for return + conv2d early return (#1828) 2024-05-27 13:19:00 -04:00
Nathaniel Simard 15d2055de8
Feat/cube/launch (#1827) 2024-05-27 12:15:06 -04:00
Adrian Müller cccd96de48
Feat: Implement ONNX RandomUniform + RandomNormal in burn-import (#1806) 2024-05-27 10:07:04 -04:00
github-actions[bot] 85ba167582
Combined PRs (#1823)
* Bump cudarc from 0.10.0 to 0.11.0

Bumps [cudarc](https://github.com/coreylowman/cudarc) from 0.10.0 to 0.11.0.
- [Release notes](https://github.com/coreylowman/cudarc/releases)
- [Commits](https://github.com/coreylowman/cudarc/compare/v0.10.0...v0.11.0)

---
updated-dependencies:
- dependency-name: cudarc
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump libc from 0.2.154 to 0.2.155

Bumps [libc](https://github.com/rust-lang/libc) from 0.2.154 to 0.2.155.
- [Release notes](https://github.com/rust-lang/libc/releases)
- [Commits](https://github.com/rust-lang/libc/compare/0.2.154...0.2.155)

---
updated-dependencies:
- dependency-name: libc
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump ratatui from 0.26.2 to 0.26.3

Bumps [ratatui](https://github.com/ratatui-org/ratatui) from 0.26.2 to 0.26.3.
- [Release notes](https://github.com/ratatui-org/ratatui/releases)
- [Changelog](https://github.com/ratatui-org/ratatui/blob/main/CHANGELOG.md)
- [Commits](https://github.com/ratatui-org/ratatui/compare/v0.26.2...v0.26.3)

---
updated-dependencies:
- dependency-name: ratatui
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump bytemuck from 1.15.0 to 1.16.0

Bumps [bytemuck](https://github.com/Lokathor/bytemuck) from 1.15.0 to 1.16.0.
- [Changelog](https://github.com/Lokathor/bytemuck/blob/main/changelog.md)
- [Commits](https://github.com/Lokathor/bytemuck/compare/v1.15.0...v1.16.0)

---
updated-dependencies:
- dependency-name: bytemuck
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump proc-macro2 from 1.0.83 to 1.0.84

Bumps [proc-macro2](https://github.com/dtolnay/proc-macro2) from 1.0.83 to 1.0.84.
- [Release notes](https://github.com/dtolnay/proc-macro2/releases)
- [Commits](https://github.com/dtolnay/proc-macro2/compare/1.0.83...1.0.84)

---
updated-dependencies:
- dependency-name: proc-macro2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-05-27 10:04:27 -04:00
Ikko Eltociear Ashimine 1c5e65ab26
docs: update README.md (#1810)
minor fix
2024-05-27 09:21:03 -04:00
Nathaniel Simard c7ad25ab60
Update cuda-jit (#1799) 2024-05-24 11:31:47 -04:00
Louis Fortier-Dubois 23c622a9f8
Feat/cube/remaining ops (#1807) 2024-05-24 09:48:34 -04:00
Justin Restivo 1670a71711
Fix burn-jit compile error (#1803) 2024-05-23 20:24:42 -04:00
jachym.putta ef4646c90f
feat: Greater + GreaterOrEqual onnx import (#1801) 2024-05-23 08:59:15 -04:00
jachym.putta 1f31e20ce8
feat: Less + LessOrEqual onnx import (#1800) 2024-05-23 08:04:44 -04:00
Louis Fortier-Dubois e39b4d2da0
refactor reduce into separate traits (#1798) 2024-05-22 16:01:27 -04:00
Louis Fortier-Dubois 033171920c
Cube: first ported kernel + comptime support + variable reuse + cleanup (#1797) 2024-05-22 14:08:21 -04:00
Guillaume Lagrange b466fd7606
Add seq start position when applying RoPE encoding (#1796) 2024-05-22 13:18:31 -04:00
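Why a start position matters for RoPE, per the standard formulation (this is the textbook formula, not a claim about Burn's exact code): the rotation angle depends on the absolute token position, so decoding with a KV cache must offset new tokens by the number already processed.

```rust
// Standard RoPE angle: theta = pos * base^(-2i/d), with base 10000.
fn rope_angle(pos: usize, i: usize, d_model: usize) -> f32 {
    (pos as f32) * 10_000f32.powf(-2.0 * (i as f32) / (d_model as f32))
}

fn main() {
    // Without the offset, the first token generated after a cache of
    // length 5 would be rotated as position 0 instead of position 5.
    assert_ne!(rope_angle(0, 1, 64), rope_angle(5, 1, 64));
}
```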
jachym.putta 0918cf00c6
feat: added min onnx import (#1778) 2024-05-22 10:52:19 -04:00
Guillaume Lagrange 550086a5c1
Fix record nested value de/serialization (#1751) 2024-05-22 09:15:32 -04:00
Louis Fortier-Dubois 6137d42c10
fix prng bug during autotune (#1791) 2024-05-22 09:11:13 -04:00
jachym.putta 8c01444fc5
Adding max import (#1769)
* feat: add max import

* feat: implement the right max operation (hopefully)
2024-05-22 08:31:55 -04:00
Mathias Insley 81ecd14f83
Feat/squeeze dims (#1779) 2024-05-22 07:53:51 -04:00
Louis Fortier-Dubois 76fe0ed881
Refactor/cube/vectorization (#1781) 2024-05-19 13:20:55 -04:00
Louis Fortier-Dubois 499ff0dd26
Feat/enable cube cl (#1777)
* Ben WIP

* Compile burn-jit

* WGPU works

* Remove old code

* move language cube stuff

* cleaning up

* some import reworking

* remove cube reexport

* template feature flag in cube

* ci

---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-05-19 10:55:04 -04:00
Mathias Insley 9c5b07c833
Squeeze Onnx Import (#1753) 2024-05-17 12:00:34 -04:00
Jonathan Richard 8de05e1419
Add configurable application logger to learner builder (#1774)
* refactor: add TracingSubscriberLogger trait and FileTracingSubscriberLogger struct

* Remove unused log module and renames, fmt

* Renamed tracing subscriber logger

* renamed to application logger installer

* book learner configuration update

* fix typo

* unused import
2024-05-16 16:25:33 -04:00
Nathaniel Simard 7ab2ba1809
Feat/cubecl ir (#1776)
---------

Co-authored-by: louisfd <louisfd94@gmail.com>
2024-05-16 15:08:53 -04:00
Louis Fortier-Dubois 542790e17e
CubeCL first iteration (#1756) 2024-05-15 10:24:37 -04:00
getumen e823338750
Add Clone trait to the `OptimizerAdaptor` and Clone implementations to the optimizers (#1770) 2024-05-15 09:18:09 -04:00
Ben Barber d3cd6c4928
Replace opaque return types in optim (#1767)
* update ARCHITECTURE.md links to project architecture section in contributor book

* replace opaque return type in optim
2024-05-13 22:21:20 -04:00
Nathaniel Simard 9dcec0b998
Refactor/jit fusion (#1750)
* Reads & Writes with index_ref

* WIP

* Fix operations

* Cleanup
2024-05-13 12:48:23 -04:00
Ahmed Yarub Hani Al Nuaimi 10737527d8
#1747 Upgrade Rust dependencies (#1748)
* #1747
Upgrade Rust dependencies

* Revert upgrade for tch

The update of tch on windows gives an error:

INTEL MKL ERROR: The specified module could not be found. mkl_vml_avx2.1.dll.
Intel MKL FATAL ERROR: cannot load mkl_vml_avx2.1.dll or mkl_vml_def.1.dll.

* Keep only .cargo/config.toml file which works with rust > 1.75

---------

Co-authored-by: Sylvain Benner <sylvain@benner.online>
2024-05-10 16:25:19 -04:00
Thierry Cantin-Demers b09d8431df
Fix Cargo.toml repository links (#1749)
* Fix wgpu github link

* Fix burn-train repo link

* Fix burn-tensor github repo

* Fix burn-tensor repo link

* Fix remaining repo links in crates Cargo.toml

---------

Co-authored-by: Jonathan Richard <47578360+jwric@users.noreply.github.com>
2024-05-09 15:40:05 -04:00
Arjun31415 5bbc5ea944
Added ONNX AvgPool1d (#1744) 2024-05-07 16:10:18 -05:00
Nathaniel Simard a6e3b4e81e
Fix select assign backward (#1739) 2024-05-07 11:37:43 -04:00
Sébastien Boisvert bd06b38fac
Refactor: replace trait TemplateKernel by existing trait JitKernel (#1737)
* Refactor: replace trait TemplateKernel by existing trait JitKernel

* Refactor: implement trait JitKernel for struct Kernel
2024-05-06 20:59:00 -04:00
Arjun31415 7f94f4c219
Add MaxPool1d ONNX Op (#1725) 2024-05-06 10:51:00 -05:00
Anton Blomström fb13503fa9
Add reduce sum onnx ops to burn imports (#1723) 2024-05-06 10:49:17 -05:00
Anton Blomström f8994e044c
Fix unstable tests when run concurrently (#1724) 2024-05-05 15:27:42 -05:00
Arjun31415 152509c378
PReLu ONNX import (#1721)
* added prelu onnx operator

* bug fix

* added onnx tests and burn codegen tests

* fix tests

* added prelu to supported onnx ops and add prelu to dim_inference
2024-05-04 13:45:42 -05:00
Louis Fortier-Dubois a8661a2f53
Autodiff Memory Management: BFS (#1710) 2024-05-03 09:45:21 -04:00
Nathaniel Simard 5d959e2884
[Fusion] Support multi-precision fusion (#1718) 2024-05-02 18:22:56 -04:00
Louis Fortier-Dubois 2e4c82fa64
Fix repeat for dims > 1 (#1713) 2024-05-01 09:11:38 -04:00
Dilshod Tadjibaev 3a02a54e55
Update SUPPORTED-ONNX-OPS.md (#1717)
Gather ONNX was checked off, but actually GatherElements should have been updated.
2024-05-01 08:02:59 -04:00
Dilshod Tadjibaev ff9e875321
ONNX debug improvements (#1712)
* Minor debug improvements

* Change warn to panic

* Log improvements
2024-04-30 16:36:55 -05:00
Nathaniel Simard 587b8f80b3
First draft CUDA runtime (#1685)
Initial cuda runtime crate with a WIP compiler.
2024-04-30 09:46:29 -04:00
Jonathan Merritt ab501431b1
Handle ndarray matmul broadcasting (#1679)
* Handle ndarray matmul broadcasting

- Use strides to map linear batch indices from
  the output back to the input arrays.

* Fix typos
2024-04-29 17:25:27 -05:00
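The stride mapping mentioned above works because a broadcast (size-1) batch dimension contributes nothing to the input offset. A self-contained sketch of mapping a linear output batch index back to an input with size-1 dims (illustrative, not the PR's code):

```rust
/// Map a linear batch index in the broadcast output shape back to the
/// linear index into an input whose shape has 1s where it was broadcast.
fn map_batch_index(out_shape: &[usize], in_shape: &[usize], mut out_idx: usize) -> usize {
    let mut in_idx = 0;
    let mut in_stride = 1;
    // Walk dimensions from innermost to outermost.
    for (&o, &i) in out_shape.iter().zip(in_shape).rev() {
        let coord = out_idx % o;
        out_idx /= o;
        if i != 1 {
            // Broadcast dims (i == 1) act like stride 0: coord is dropped.
            in_idx += coord * in_stride;
        }
        in_stride *= i;
    }
    in_idx
}

fn main() {
    // Output batch shape [2, 3], input batch shape [1, 3]:
    // every output row maps back to the single input row.
    assert_eq!(map_batch_index(&[2, 3], &[1, 3], 5), 2); // (1, 2) -> (0, 2)
    assert_eq!(map_batch_index(&[2, 3], &[1, 3], 2), 2); // (0, 2) -> (0, 2)
}
```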
Dilshod Tadjibaev 1cdceb590f
Skip updating shape for linear if not present (#1700) 2024-04-29 14:53:18 -05:00
WU Chen b387829731
Implement bidirectional LSTM (#1035)
* resolve conflict

* move `gate_product` to `GateController`

* BiLstm needs to use its own initializer when initialized

* resolve conflicts

* add some comments

* improve doc

* correct the description of GateController

* fix fmt

* add `LstmState`

* add test for state

* set batch 2 in bilstm test

* resolve conflict

* fix

* fix doc

* change the batch size back to 1

* change the batch size back to 1

* modify docstring; delete dead comment
2024-04-26 13:28:36 -05:00
Louis Fortier-Dubois 6ae3926006
New autodiff graph memory management strategy (#1698)
---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-04-26 12:25:53 -04:00
Nathaniel Simard 2f294c5092
Fix lstm batch size bug (#1695) 2024-04-26 08:54:12 -04:00
Guillaume Lagrange ce2429eb10
Refactor element type to be decoupled from runtime (#1693) 2024-04-26 08:53:55 -04:00
Dilshod Tadjibaev 67ec06d5d8
ONNX support for scalar unsqueeze (#1690)
* Revert 1c639c8393

1c639c8393?diff=unified&w=0

* Refactor by @laggui

* Refactor unsqueeze

* Add support for scalar unsqueeze

* Removed dead comment
2024-04-25 16:05:28 -05:00
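The scalar case above is the degenerate form of standard unsqueeze semantics (insert size-1 dimensions at the given axes, indexed against the output shape). A sketch of just the shape rule, not the importer's code:

```rust
// Insert size-1 dims at `axes` (axes indexed against the output shape).
fn unsqueeze_shape(shape: &[usize], axes: &[usize]) -> Vec<usize> {
    let mut out = shape.to_vec();
    let mut axes = axes.to_vec();
    axes.sort_unstable();
    for &axis in &axes {
        out.insert(axis, 1);
    }
    out
}

fn main() {
    assert_eq!(unsqueeze_shape(&[3, 4], &[0, 3]), vec![1, 3, 4, 1]);
    // A scalar (rank 0) unsqueezed at axis 0 becomes a rank-1 tensor.
    assert_eq!(unsqueeze_shape(&[], &[0]), vec![1]);
}
```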
Nathaniel Simard 599a20d586
Upgrade wgpu (#1692) 2024-04-25 16:32:50 -04:00
Dilshod Tadjibaev a1bd14c5ae
Reshape bug fix (#1684)
* Revert 1c639c8393

1c639c8393?diff=unified&w=0

* Refactor by @laggui

* Refactor unsqueeze
2024-04-24 19:31:53 -05:00
Nathaniel Simard 886a1de235
Refactor/burn compute (#1580) 2024-04-23 13:05:15 -04:00
Sylvain Benner c579686a8a
Move HandleContainer and Tensor Ops descriptions from burn-fusion to burn-tensor (#1654)
* Move HandleContainer and Tensor Ops descriptions to burn-tensor

Move HandleContainer and Tensor operations descriptions to burn-tensor crate.
Removed the FusionDevice and replaced it with a DeviceOps trait bound to Backend::Device.

For now, the modules added to burn-tensor are excluded from no-std, as they rely on Arc.

* [burn-tensor] Flatten module hierarchy for tensor representation

+ Add new repr feature to cargo file.

* Remove prefix on docstring

* [burn-fusion] Require default features of burn-tensor
2024-04-23 11:27:54 -04:00
Guillaume Lagrange e6b1b7a317
Add layer norm onnx op support (#1680) 2024-04-23 11:19:07 -04:00
Dilshod Tadjibaev 1718da5210
Fix reshape bug (support for opset version 1) (#1667)
* Make reshape op version 1

* Refactor per PR feedback
2024-04-22 17:52:25 -05:00
Nathaniel Simard 29fa2ee76c
Support linear 1d (#1682) 2024-04-22 18:39:09 -04:00
Alex Errant d62b344d5b
`Arc<EventStoreClient>` to `Rc<EventStoreClient>` (#1668) 2024-04-22 18:21:53 -04:00
신희제(Heejae Shin/joel.barish) 2a7b296a1b
Add sign ONNX op import support (#1663)
* Add sign ONNX op support

* Update SUPPORTED-ONNX-OPS.md
2024-04-22 09:10:50 -04:00
Louis Fortier-Dubois 2140d9b568
remove JIT subsequent RNG tests (#1652) 2024-04-21 09:48:11 -04:00
Dilshod Tadjibaev 1433284a0f
Fix bug 1645 (Unsqueeze OpSet 11) (#1661)
* Add unsqueeze opset 16 test

* Fix for unsqueeze opset 11

* Remove println statement
2024-04-19 14:17:44 -05:00
Guillaume Lagrange b65a487300
Fix transpose onnx op (permute) (#1657) 2024-04-19 09:34:03 -04:00
Nico Zweifel ee12aee2e7
fix: `window` -> `pub window` in `dataset/mod.rs` (#1658)
* update dataset/mod.rs

* Update mod.rs

* Update window.rs
2024-04-19 09:33:21 -04:00
Guillaume Lagrange 9fbcbed20f
Add where onnx op support (#1653)
* Add where onnx op support

* Add broadcasting support

* Remove broadcasting limitation comment

* Fix broadcasting in mask where

* Forgot to reflect changes in codegen test

* Fix clippy
2024-04-18 15:46:02 -04:00
Guillaume Lagrange 7705fd9c25
Add matmul ONNX op support (#1638)
* Mul onnx op already supported

* Add matmul onnx op checks and tests

* Add missing eq derives

* Change superscript symbol

* Remove dead code

* Add support for matmul broadcast

* No more broadcasting restrictions

* Add results comment for mm, mv and vm
2024-04-18 09:20:31 -04:00
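The "mm, mv and vm" cases above follow the standard matmul shape rules; a quick sketch of just the shape arithmetic (standard semantics, not Burn's code):

```rust
// Result shapes for matrix-matrix, matrix-vector, and vector-matrix.
fn matmul_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
    match (lhs, rhs) {
        ([n, m1], [m2, p]) if m1 == m2 => vec![*n, *p], // mm: [n,m] x [m,p] -> [n,p]
        ([n, m1], [m2]) if m1 == m2 => vec![*n],        // mv: [n,m] x [m]   -> [n]
        ([m1], [m2, p]) if m1 == m2 => vec![*p],        // vm: [m]   x [m,p] -> [p]
        _ => panic!("incompatible shapes"),
    }
}

fn main() {
    assert_eq!(matmul_shape(&[2, 3], &[3, 4]), vec![2, 4]);
    assert_eq!(matmul_shape(&[2, 3], &[3]), vec![2]);
    assert_eq!(matmul_shape(&[3], &[3, 4]), vec![4]);
}
```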
Dilshod Tadjibaev 2a721a9d0c
Enable native sign operation for Candle backend (#1647)
* Enable native sign operation for Candle backend

* Use fixed revision
2024-04-17 09:07:56 -04:00
Guillaume Lagrange 424033283a
Add reduce max ONNX op support (#1636)
* Add reduce max onnx op support

* Fix comments on tensor rank 1 result
2024-04-17 08:26:46 -04:00
Nico Zweifel 5a3f345734
WindowDataset/windows function (#1553) 2024-04-17 07:51:53 -04:00
Guillaume Lagrange 35b36bbe62
Add shape ONNX op support (#1639)
* Add shape onnx op support

* Remove cast node from onnx graph

* Fix shape implementation

* Fix shape config error message

* Fix typo

* Fix clippy type complexity for generated code
2024-04-16 09:28:21 -04:00
Guillaume Lagrange 6d96e8d808
[ONNX] Add not op and extend cast support to tensors (#1634)
* Add not onnx op support

* Extend cast onnx support to tensors

* Fix clippy
2024-04-16 08:45:25 -04:00
Mathias Insley 7377bbe31c
Feat/remainder (#1597)
* Add remainder_scalar op to numeric trait and associated int/float functions

* Update burn-tch crate

* Update ndarray crate

* Update jit crate

* Update candle crate

* Update fusion crate

* Update autodiff crate

* Forgot float.rs for fusion

* Add burn-tensor tests

* Redirect to the pre-existing modulus op

* Fix sign

* Remove mut from burn-tch

* Use sign trick to make wgpu backend work

* Add more unit tests to cover the bases

* Naming fix for burn-fusion

* Update tests w/PyTorch link

* Use different WGSL instructions for remainder

* Redirect to remainder Operator instead of modulo

* Revert Modulo in instruction.rs
2024-04-16 08:35:20 -04:00
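On the "sign trick" above: a backend that only exposes truncated `%` can still get floored (PyTorch-style) remainder by re-adding the divisor and reducing again. One common form of the trick, shown here for clarity rather than as the exact WGSL emitted:

```rust
// Floored modulo built from truncated `%`: ((a % b) + b) % b.
fn rem_floored(a: i32, b: i32) -> i32 {
    ((a % b) + b) % b
}

fn main() {
    assert_eq!(-7 % 3, -1);            // truncated: sign of the dividend
    assert_eq!(rem_floored(-7, 3), 2); // floored: sign of the divisor
    assert_eq!(rem_floored(7, -3), -2);
}
```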
Mathias Insley 48c61ebb81
Docs/update contributor book (#1622)
* Update links to latest commit off main

* Some pedantry

* Update links and add jit

* Update instructions for burn-jit and wgpu

* Updated import section with more recent links

* Some grammar/typo/styling fixes

* Code added to burn-wgpu too
2024-04-16 08:33:59 -04:00
Guillaume Lagrange d5f20e2711
Add reduce mean ONNX op support (#1637)
* Add reduce mean onnx op support

* Fix comment
2024-04-16 07:59:35 -04:00
Dilshod Tadjibaev 340a12463a
Update SUPPORTED-ONNX-OPS.md (#1641) 2024-04-16 07:52:15 -04:00
Guillaume Lagrange 81a67b6a09
Add sin onnx op support (#1633) 2024-04-15 15:28:16 -04:00
Sylvain Benner e303e31c8b
Bump next version of Burn to 0.14.0 (#1618) 2024-04-12 17:14:45 -04:00
Guillaume Lagrange cf7b279e5e
Fix burn README symlink (#1617) 2024-04-12 16:00:47 -04:00
Guillaume Lagrange 9980db440d
Remove unused assets (#1616) 2024-04-12 15:48:16 -04:00
Guillaume Lagrange 264c167c11
Update licenses symlinks (#1613) 2024-04-12 14:43:58 -04:00
Nathaniel Simard ff844b1667
Fix candle backend sync (#1579)
* Fix candle backend sync

* tch mps sync

* clippy

---------

Co-authored-by: louisfd <louisfd94@gmail.com>
2024-04-12 12:15:50 -04:00
Aasheesh Singh fb1da53a38
Support for rotary positional encoding in transformer modules (#1604)
* add rotary positional encoding to transformer modules.

* fix f64 error

* use num_traits

* add panic condition
2024-04-12 11:45:49 -04:00
Louis Fortier-Dubois 23210f05f2
JIT: Autotune matmul tiling 2d unroll (#1601)
* autotune tiling 2d unroll

* clippy

* forgotten important stuff
2024-04-12 10:15:21 -04:00
Nathaniel Simard 07a61a1cec
Fix autodiff memory management graph cleaning (#1602) 2024-04-11 16:21:00 -04:00
Guillaume Lagrange 0cbe9a927d
Add learner training report summary (#1591)
* Add training report summary

* Fix LossMetric batch size state

* Add NumericEntry de/serialize

* Fix clippy suggestion

* Compact recorder does not use compression (anymore)

* Add learner summary expected results tests

* Add summary to learner builder and automatically display in fit

- Add LearnerSummaryConfig
- Keep track of summary metrics names
- Add model field when displaying from learner.fit()
2024-04-11 12:32:25 -04:00