Commit Graph

1410 Commits

Author SHA1 Message Date
Guillaume Lagrange aa79e36a8d
Add more quantization support for burn-jit (#2275)
* Add cubecl quantization kernels and QTensorOps for burn-jit

* Fix typo

* Fix output vec factor

* Fix output dtype size_of

* Remove unused code in dequantize test

* Fix dequantize vectorization

* Handle tensors when number of elems is not a multiple of 4

* Support quantize for tensors with less than 4 elems (no vectorization)

* Fix equal 0 test

* Add quantize/dequantize tests

* Add q_to_device

* Refactor kernels for latest cubecl

* intermediate i32 cast

* Fix size_of output type

* Use strict=false to ignore floating point precision issues with qparams equality

* Only check that lhs & rhs strategies match (but not strict on qparams values)

* Use assert_approx_eq on dequant values

* Reduce precision for flaky test

* Remove todo comment

* Add comment for cast to unsigned

* More comment

---------

Co-authored-by: louisfd <louisfd94@gmail.com>
2024-09-17 10:08:20 -04:00
Nathaniel Simard 834005eadb
Update rev (#2283) 2024-09-17 09:25:44 -04:00
tiruka c0656b5f9f
modify unresolved import `regression` (#2285) 2024-09-17 08:33:09 -04:00
Asher Jingkong Chen 7ac5deebe2
Refactor burn-tensor: Split conv backward ops to allow conditional gradient computation (#2278) 2024-09-16 10:15:27 -04:00
Guillaume Lagrange 81ec64a929
Add ResNet benchmark (#1534) 2024-09-16 09:57:15 -04:00
github-actions[bot] 5631afb3a0
Combined PRs (#2282)
* Bump arboard from 3.4.0 to 3.4.1

Bumps [arboard](https://github.com/1Password/arboard) from 3.4.0 to 3.4.1.
- [Release notes](https://github.com/1Password/arboard/releases)
- [Changelog](https://github.com/1Password/arboard/blob/master/CHANGELOG.md)
- [Commits](https://github.com/1Password/arboard/compare/v3.4.0...v3.4.1)

---
updated-dependencies:
- dependency-name: arboard
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump dashmap from 6.0.1 to 6.1.0

Bumps [dashmap](https://github.com/xacrimon/dashmap) from 6.0.1 to 6.1.0.
- [Release notes](https://github.com/xacrimon/dashmap/releases)
- [Commits](https://github.com/xacrimon/dashmap/compare/v6.0.1...v6.1.0)

---
updated-dependencies:
- dependency-name: dashmap
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump pretty_assertions from 1.4.0 to 1.4.1

Bumps [pretty_assertions](https://github.com/rust-pretty-assertions/rust-pretty-assertions) from 1.4.0 to 1.4.1.
- [Release notes](https://github.com/rust-pretty-assertions/rust-pretty-assertions/releases)
- [Changelog](https://github.com/rust-pretty-assertions/rust-pretty-assertions/blob/main/CHANGELOG.md)
- [Commits](https://github.com/rust-pretty-assertions/rust-pretty-assertions/compare/v1.4.0...v1.4.1)

---
updated-dependencies:
- dependency-name: pretty_assertions
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-09-16 09:20:28 -04:00
nathaniel 395d84ce71 Fix comments 2024-09-16 09:10:34 -04:00
Periwink a1d2b13e3e
add comments to burn fusion (#2130) 2024-09-16 09:02:12 -04:00
Guillaume Lagrange 6f0e61aa4f
Change ndarray mask_where implementation to correctly deal with NaNs (#2272)
* Change ndarray mask_where implementation to correctly deal with NaNs

* Add test
2024-09-13 15:16:39 -04:00
dependabot[bot] 2fbad48f64
Bump serde from 1.0.209 to 1.0.210 (#2255)
Bumps [serde](https://github.com/serde-rs/serde) from 1.0.209 to 1.0.210.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.209...v1.0.210)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-13 14:59:50 -04:00
dependabot[bot] bf207f0b1b
Bump serde_json from 1.0.127 to 1.0.128 (#2254)
Bumps [serde_json](https://github.com/serde-rs/json) from 1.0.127 to 1.0.128.
- [Release notes](https://github.com/serde-rs/json/releases)
- [Commits](https://github.com/serde-rs/json/compare/1.0.127...1.0.128)

---
updated-dependencies:
- dependency-name: serde_json
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-13 14:18:59 -04:00
dependabot[bot] 8a19b06d3b
Bump flate2 from 1.0.32 to 1.0.33 (#2256)
Bumps [flate2](https://github.com/rust-lang/flate2-rs) from 1.0.32 to 1.0.33.
- [Release notes](https://github.com/rust-lang/flate2-rs/releases)
- [Changelog](https://github.com/rust-lang/flate2-rs/blob/main/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/flate2-rs/compare/1.0.32...1.0.33)

---
updated-dependencies:
- dependency-name: flate2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-13 13:42:43 -04:00
dependabot[bot] 8fb3de9f8d
Bump clap from 4.5.16 to 4.5.17 (#2258)
Bumps [clap](https://github.com/clap-rs/clap) from 4.5.16 to 4.5.17.
- [Release notes](https://github.com/clap-rs/clap/releases)
- [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md)
- [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.16...clap_complete-v4.5.17)

---
updated-dependencies:
- dependency-name: clap
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-13 13:34:02 -04:00
dependabot[bot] e4f60b837b
Bump bytemuck from 1.17.1 to 1.18.0 (#2257)
Bumps [bytemuck](https://github.com/Lokathor/bytemuck) from 1.17.1 to 1.18.0.
- [Changelog](https://github.com/Lokathor/bytemuck/blob/main/changelog.md)
- [Commits](https://github.com/Lokathor/bytemuck/compare/v1.17.1...v1.18.0)

---
updated-dependencies:
- dependency-name: bytemuck
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-13 13:05:58 -04:00
Nathaniel Simard 58ce502498
Fix (#2269) 2024-09-10 13:36:00 -04:00
Nathaniel Simard d3fbdeaa48
Fix CI (#2268) 2024-09-10 12:13:48 -04:00
Genna Wingert 17050db57e
Migrate cubecl macro (#2266) 2024-09-10 11:31:02 -04:00
Sylvain Benner e9ec35f764
[CI] Fix llvmpipe, lavapipe install for valgrind and vulnerabilities (#2264)
* [CI] Fix llvmpipe, lavapipe install for valgrind and vulnerabilities

* Test github-actions v1.4.0

* Use v1 tag of github-actions
2024-09-09 18:29:01 -04:00
Guillaume Lagrange eb899db16c
Add ops w/ default implementation for `QTensorOps` (#2125)
* Add q_* ops to match float ops

* Refactor q_* ops w/ dequant_op_quant macro

* Comparison ops are already implemented by default to compare dequantized values

* Add default arg min/max implementation and fix tch implementation

* Avoid division by zero scale

* Add default q_gather implementation (tch does not support on quantized tensor)

* Add warning instead for tch quantize_dynamic

* Call chunk backend implementation

* Add QFloat check for q_ ops

* Add tch q_min/max_dim_with_indices

* Add q_ ops tests

* Clippy fix

* Remove dead code/comments

* Fix quantization tests precision

* Set higher tolerance for ndarray backend

* Remove comment
2024-09-09 12:21:47 -04:00
Joshua Ferguson 9e9451bb60
simplify scope tracking in burn-import (#2207)
* simplify scope tracking in burn-import

* removed unecessary return statement
2024-09-09 12:19:26 -04:00
Asher Jingkong Chen ccb5b2214e
Fix burn-jit conv2d excessive loop unrolling (#2263)
* Related to issue #2260
2024-09-09 11:16:13 -04:00
Nathaniel Simard 94cd8a2556
Perf/slice (#2252) 2024-09-09 11:08:39 -04:00
Mehmet Ali Anil 3d91b40005
fixed path (#2262) 2024-09-09 10:29:27 -04:00
dependabot[bot] a9f941d403
Bump peter-evans/create-pull-request from 6 to 7 (#2245)
Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 6 to 7.
- [Release notes](https://github.com/peter-evans/create-pull-request/releases)
- [Commits](https://github.com/peter-evans/create-pull-request/compare/v6...v7)

---
updated-dependencies:
- dependency-name: peter-evans/create-pull-request
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-06 16:22:13 -04:00
Paul Wagener 0f191e67aa
Fix panic messages being invisible in tui mode (#2226)
* Fix panic messages being invisible in tui mode

Currently when a panic happens the message gets printed to the alternate screen which gets erased after the terminal is reset to raw mode in the TuiMetricsRenderer drop code.

That leaves users unable to see the panic message (issue #2062).

This commit changes TuiMetricsRenderer to reset the terminal first during a panic and then running the panic handler.

* Use PanicInfo to support Rust version < 1.82
2024-09-06 16:22:00 -04:00
Nathaniel Simard a567c6e888
Fusion mix precision (#2247) 2024-09-05 10:53:26 -04:00
Asher Jingkong Chen fc311323d9
[burn-autodiff] Fix abs NaN when output is 0 (#2249) 2024-09-05 09:03:24 -04:00
Sylvain Benner 3dfd99c18b
Set tracel-xtask version to 1.0.x (#2250) 2024-09-05 00:24:33 -04:00
Sylvain Benner 6787e778bc
Update CI workflow for last version of setup-linux action (#2248) 2024-09-04 16:07:48 -04:00
github-actions[bot] e1e6665365
Combined PRs (#2241)
* Bump dashmap from 5.5.3 to 6.0.1

Bumps [dashmap](https://github.com/xacrimon/dashmap) from 5.5.3 to 6.0.1.
- [Release notes](https://github.com/xacrimon/dashmap/releases)
- [Commits](https://github.com/xacrimon/dashmap/compare/v.5.5.3...v6.0.1)

---
updated-dependencies:
- dependency-name: dashmap
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump bytemuck from 1.17.0 to 1.17.1

Bumps [bytemuck](https://github.com/Lokathor/bytemuck) from 1.17.0 to 1.17.1.
- [Changelog](https://github.com/Lokathor/bytemuck/blob/main/changelog.md)
- [Commits](https://github.com/Lokathor/bytemuck/compare/v1.17.0...v1.17.1)

---
updated-dependencies:
- dependency-name: bytemuck
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump serde from 1.0.208 to 1.0.209

Bumps [serde](https://github.com/serde-rs/serde) from 1.0.208 to 1.0.209.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.208...v1.0.209)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump syn from 2.0.75 to 2.0.77

Bumps [syn](https://github.com/dtolnay/syn) from 2.0.75 to 2.0.77.
- [Release notes](https://github.com/dtolnay/syn/releases)
- [Commits](https://github.com/dtolnay/syn/compare/2.0.75...2.0.77)

---
updated-dependencies:
- dependency-name: syn
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump tokio from 1.39.3 to 1.40.0

Bumps [tokio](https://github.com/tokio-rs/tokio) from 1.39.3 to 1.40.0.
- [Release notes](https://github.com/tokio-rs/tokio/releases)
- [Commits](https://github.com/tokio-rs/tokio/compare/tokio-1.39.3...tokio-1.40.0)

---
updated-dependencies:
- dependency-name: tokio
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-09-04 09:30:55 -04:00
Nathaniel Simard 756931211c
update cubecl (#2244) 2024-09-03 17:57:39 -04:00
Bjorn Beishline 474fa113d6
Fixed raspberry pi pico example not compiling (#2220)
The import for the model didn't get renamed properly when the name changed from onnx-inference-rp2040 to raspberry-pi-pico.
2024-09-03 11:46:11 -04:00
Adrian Müller 6b51b73a5f
Fix ONNX where op for scalar inputs (#2218)
* Fix ONNX where op dim_inference for scalar inputs

* Rewrite ONNX Where codegen to support scalars

* ONNX Where: Add tests for all_scalar inputs

---------

Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
2024-09-03 11:17:18 -04:00
Guillaume Lagrange 59d41bd4b2
Remove copy restriction for const generic modules (#2222) 2024-09-03 09:39:12 -04:00
Guillaume Lagrange cc214d366c
Nonzero should return an empty vec for zero tensors (#2212)
* Nonzero should return an empty vec for zero tensors

* Add nonzero empty test

* Add missing import

---------

Co-authored-by: Nathaniel Simard <nathaniel.simard.42@gmail.com>
2024-09-03 09:00:58 -04:00
Nathaniel Simard 96a23408d2
Chore/update cubecl (#2235) 2024-09-01 17:05:13 -04:00
Paul Wagener c1b61033f4
Fix compile for dataset crate with vision feature (#2228)
This fixes the compile error when burn is compiled with only the `dataset` and `vision` feature enabled

burn = { default-features = false, features = ["dataset", "vision"] }
2024-09-01 17:03:37 -04:00
tiruka ce2a50880b
fixed the debugger settings doc (#2223) 2024-09-01 16:37:56 -04:00
Guillaume Lagrange 09a15e7e15
Avoid 0 denominator in interpolate frac (#2224) 2024-09-01 16:37:32 -04:00
Paul Wagener 23622d765d
Don't panic when the progress is > 1.0 (#2229)
Ratatui asserts that gauges don't have a progress greater than 1.0
This can happen if a dataset reports a lower len() than it actually provides.

This change prevents a panic when the `Progress::items_processed` is greater than the `Progress::items_total`
2024-09-01 16:33:25 -04:00
Dilshod Tadjibaev 44030ead17
Create CITATION.cff (#2231) 2024-09-01 16:32:07 -04:00
王翼翔 66ee3bb3bc
Update huber.rs (#2232) 2024-09-01 16:31:07 -04:00
Sylvain Benner e8828afb29
Automatic minimum rust version in README (#2227) 2024-08-30 19:30:54 -04:00
Sylvain Benner 9c5cb511fa
Rename CI workflow back to test.yml (#2225)
* Rename CI workflow back to test.yml

* Static job names in test workflow

This allows to have static status check names

* Remove need for a cache-version matrix variable
2024-08-30 19:09:55 -04:00
Nathaniel Simard 0dbb7f7e91
Chore: Update cubecl (#2219) 2024-08-30 15:28:00 -04:00
Guillaume Lagrange a9abd8f746
Add missing output padding to conv transpose ONNX (#2216)
* Add output_padding support for ONNX ConvTranspose

* Add missing codegen

* Fix output padding codegen test
2024-08-29 14:07:00 -04:00
Dilshod Tadjibaev 28c2d4e3cd
Update SUPPORTED-ONNX-OPS.md (#2217) 2024-08-29 14:06:42 -04:00
Adrian Müller e8ea9e27c2
Improve ONNX import tensor shape tracking (#2213)
- Calculate result of broadcasting in dim_inference
- keep Shape info when converting from Argument to TensorType
- Remove a few sources of Dim = 0 Tensors, create Scalars instead
- Clean up dim_inference a bit
2024-08-29 14:06:30 -04:00
Adrian Müller 2f4c5ac0a1
Feat: Allow onnx-import expand op with non-const shapes (#2189)
* Feat: Allow onnx-import expand op with non-const shapes

* Generalize ONNX Expand across IntElem
2024-08-29 13:15:44 -04:00
Guillaume Lagrange 7baa33bdaa
Fix target convert in batcher and align guide imports (#2215)
* Fix target convert in batcher

* Align hidden code in training and update loss to use config

* Align imports with example

* Remove unused import and fix guide->crate
2024-08-29 08:58:51 -04:00