Commit Graph

1395 Commits

Author SHA1 Message Date
Nathaniel Simard d3fbdeaa48
Fix CI (#2268) 2024-09-10 12:13:48 -04:00
Genna Wingert 17050db57e
Migrate cubecl macro (#2266) 2024-09-10 11:31:02 -04:00
Sylvain Benner e9ec35f764
[CI] Fix llvmpipe, lavapipe install for valgrind and vulnerabilities (#2264)
* [CI] Fix llvmpipe, lavapipe install for valgrind and vulnerabilities

* Test github-actions v1.4.0

* Use v1 tag of github-actions
2024-09-09 18:29:01 -04:00
Guillaume Lagrange eb899db16c
Add ops w/ default implementation for `QTensorOps` (#2125)
* Add q_* ops to match float ops

* Refactor q_* ops w/ dequant_op_quant macro

* Comparison ops are already implemented by default to compare dequantized values

* Add default arg min/max implementation and fix tch implementation

* Avoid division by zero scale

* Add default q_gather implementation (tch does not support on quantized tensor)

* Add warning instead for tch quantize_dynamic

* Call chunk backend implementation

* Add QFloat check for q_ ops

* Add tch q_min/max_dim_with_indices

* Add q_ ops tests

* Clippy fix

* Remove dead code/comments

* Fix quantization tests precision

* Set higher tolerance for ndarray backend

* Remove comment
2024-09-09 12:21:47 -04:00
Joshua Ferguson 9e9451bb60
simplify scope tracking in burn-import (#2207)
* simplify scope tracking in burn-import

* removed unecessary return statement
2024-09-09 12:19:26 -04:00
Asher Jingkong Chen ccb5b2214e
Fix burn-jit conv2d excessive loop unrolling (#2263)
* Related to issue #2260
2024-09-09 11:16:13 -04:00
Nathaniel Simard 94cd8a2556
Perf/slice (#2252) 2024-09-09 11:08:39 -04:00
Mehmet Ali Anil 3d91b40005
fixed path (#2262) 2024-09-09 10:29:27 -04:00
dependabot[bot] a9f941d403
Bump peter-evans/create-pull-request from 6 to 7 (#2245)
Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 6 to 7.
- [Release notes](https://github.com/peter-evans/create-pull-request/releases)
- [Commits](https://github.com/peter-evans/create-pull-request/compare/v6...v7)

---
updated-dependencies:
- dependency-name: peter-evans/create-pull-request
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-06 16:22:13 -04:00
Paul Wagener 0f191e67aa
Fix panic messages being invisible in tui mode (#2226)
* Fix panic messages being invisible in tui mode

Currently when a panic happens the message gets printed to the alternate screen which gets erased after the terminal is reset to raw mode in the TuiMetricsRenderer drop code.

That leaves users unable to see the panic message (issue #2062).

This commit changes TuiMetricsRenderer to reset the terminal first during a panic and then running the panic handler.

* Use PanicInfo to support Rust version < 1.82
2024-09-06 16:22:00 -04:00
Nathaniel Simard a567c6e888
Fusion mix precision (#2247) 2024-09-05 10:53:26 -04:00
Asher Jingkong Chen fc311323d9
[burn-autodiff] Fix abs NaN when output is 0 (#2249) 2024-09-05 09:03:24 -04:00
Sylvain Benner 3dfd99c18b
Set tracel-xtask version to 1.0.x (#2250) 2024-09-05 00:24:33 -04:00
Sylvain Benner 6787e778bc
Update CI workflow for last version of setup-linux action (#2248) 2024-09-04 16:07:48 -04:00
github-actions[bot] e1e6665365
Combined PRs (#2241)
* Bump dashmap from 5.5.3 to 6.0.1

Bumps [dashmap](https://github.com/xacrimon/dashmap) from 5.5.3 to 6.0.1.
- [Release notes](https://github.com/xacrimon/dashmap/releases)
- [Commits](https://github.com/xacrimon/dashmap/compare/v.5.5.3...v6.0.1)

---
updated-dependencies:
- dependency-name: dashmap
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump bytemuck from 1.17.0 to 1.17.1

Bumps [bytemuck](https://github.com/Lokathor/bytemuck) from 1.17.0 to 1.17.1.
- [Changelog](https://github.com/Lokathor/bytemuck/blob/main/changelog.md)
- [Commits](https://github.com/Lokathor/bytemuck/compare/v1.17.0...v1.17.1)

---
updated-dependencies:
- dependency-name: bytemuck
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump serde from 1.0.208 to 1.0.209

Bumps [serde](https://github.com/serde-rs/serde) from 1.0.208 to 1.0.209.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.208...v1.0.209)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump syn from 2.0.75 to 2.0.77

Bumps [syn](https://github.com/dtolnay/syn) from 2.0.75 to 2.0.77.
- [Release notes](https://github.com/dtolnay/syn/releases)
- [Commits](https://github.com/dtolnay/syn/compare/2.0.75...2.0.77)

---
updated-dependencies:
- dependency-name: syn
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump tokio from 1.39.3 to 1.40.0

Bumps [tokio](https://github.com/tokio-rs/tokio) from 1.39.3 to 1.40.0.
- [Release notes](https://github.com/tokio-rs/tokio/releases)
- [Commits](https://github.com/tokio-rs/tokio/compare/tokio-1.39.3...tokio-1.40.0)

---
updated-dependencies:
- dependency-name: tokio
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-09-04 09:30:55 -04:00
Nathaniel Simard 756931211c
update cubecl (#2244) 2024-09-03 17:57:39 -04:00
Bjorn Beishline 474fa113d6
Fixed raspberry pi pico example not compiling (#2220)
The import for the model didn't get renamed properly when the name changed from onnx-inference-rp2040 to raspberry-pi-pico.
2024-09-03 11:46:11 -04:00
Adrian Müller 6b51b73a5f
Fix ONNX where op for scalar inputs (#2218)
* Fix ONNX where op dim_inference for scalar inputs

* Rewrite ONNX Where codegen to support scalars

* ONNX Where: Add tests for all_scalar inputs

---------

Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
2024-09-03 11:17:18 -04:00
Guillaume Lagrange 59d41bd4b2
Remove copy restriction for const generic modules (#2222) 2024-09-03 09:39:12 -04:00
Guillaume Lagrange cc214d366c
Nonzero should return an empty vec for zero tensors (#2212)
* Nonzero should return an empty vec for zero tensors

* Add nonzero empty test

* Add missing import

---------

Co-authored-by: Nathaniel Simard <nathaniel.simard.42@gmail.com>
2024-09-03 09:00:58 -04:00
Nathaniel Simard 96a23408d2
Chore/update cubecl (#2235) 2024-09-01 17:05:13 -04:00
Paul Wagener c1b61033f4
Fix compile for dataset crate with vision feature (#2228)
This fixes the compile error when burn is compiled with only the `dataset` and `vision` feature enabled

burn = { default-features = false, features = ["dataset", "vision"] }
2024-09-01 17:03:37 -04:00
tiruka ce2a50880b
fixed the debugger settings doc (#2223) 2024-09-01 16:37:56 -04:00
Guillaume Lagrange 09a15e7e15
Avoid 0 denominator in interpolate frac (#2224) 2024-09-01 16:37:32 -04:00
Paul Wagener 23622d765d
Don't panic when the progress is > 1.0 (#2229)
Ratatui asserts that gauges don't have a progress greater than 1.0
This can happen if a dataset reports a lower len() than it actually provides.

This change prevents a panic when the `Progress::items_processed` is greater than the `Progress::items_total`
2024-09-01 16:33:25 -04:00
Dilshod Tadjibaev 44030ead17
Create CITATION.cff (#2231) 2024-09-01 16:32:07 -04:00
王翼翔 66ee3bb3bc
Update huber.rs (#2232) 2024-09-01 16:31:07 -04:00
Sylvain Benner e8828afb29
Automatic minimum rust version in README (#2227) 2024-08-30 19:30:54 -04:00
Sylvain Benner 9c5cb511fa
Rename CI workflow back to test.yml (#2225)
* Rename CI workflow back to test.yml

* Static job names in test workflow

This allows to have static status check names

* Remove need for a cache-version matrix variable
2024-08-30 19:09:55 -04:00
Nathaniel Simard 0dbb7f7e91
Chore: Update cubecl (#2219) 2024-08-30 15:28:00 -04:00
Guillaume Lagrange a9abd8f746
Add missing output padding to conv transpose ONNX (#2216)
* Add output_padding support for ONNX ConvTranspose

* Add missing codegen

* Fix output padding codegen test
2024-08-29 14:07:00 -04:00
Dilshod Tadjibaev 28c2d4e3cd
Update SUPPORTED-ONNX-OPS.md (#2217) 2024-08-29 14:06:42 -04:00
Adrian Müller e8ea9e27c2
Improve ONNX import tensor shape tracking (#2213)
- Calculate result of broadcasting in dim_inference
- keep Shape info when converting from Argument to TensorType
- Remove a few sources of Dim = 0 Tensors, create Scalars instead
- Clean up dim_inference a bit
2024-08-29 14:06:30 -04:00
Adrian Müller 2f4c5ac0a1
Feat: Allow onnx-import expand op with non-const shapes (#2189)
* Feat: Allow onnx-import expand op with non-const shapes

* Generalize ONNX Expand across IntElem
2024-08-29 13:15:44 -04:00
Guillaume Lagrange 7baa33bdaa
Fix target convert in batcher and align guide imports (#2215)
* Fix target convert in batcher

* Align hidden code in training and update loss to use config

* Align imports with example

* Remove unused import and fix guide->crate
2024-08-29 08:58:51 -04:00
Sylvain Benner a88c69af4a
Refactor xtask to use tracel-xtask and refactor CI workflow (#2063)
* Migrate to xtask-common crate

* Fix example crate name for simple-regression

* Refactor CI workflows

* Flatten linux workflows

* Install grcov and typos from binaries

Although xtask-common support auto-installation of these tools via cargo
it is a lot faster to install them via the distributed binaries

* [CI] Update Rust caches on failure

* [CI] Add shell bash to jobs steps

* [CI] Try cache all crates

* Fix no-std tests not executing

* [CI] Add CARGO_INCREMENTAL 0

* Exclude tch and cuda from tests and merge crates and examples steps

* Fix some typos found with typos cli

* Add Windows and MacOS jobs

* Only test no-std with default rust target

* Fix syntax in composite action setup-windows

* Enable incremental build

* Upate cargo alias for xtask

* Bump to github action checkout v4

* Revert to tch 0.15 and disable WGPU on windows

* Fix color in output

* Add Test command

* Test long output errorring

* Build and test workspace before additional builds and tests

* Disable wgpu tests on windows

* Remove tests- prefix in CI workflow jobs name

* Add Checks command

* Rename ci workflow jobs

* Execute windows and macos CI tests on rust stable only

* Rename integration test files with a test_ prefix

* Fix format

* Don't auto-correct "arange" with typos

* Fix typos in code

* Merge unit and integration tests steps

* Fix macos tests

* Fix coverage step

* Name publish-crate workflow

* Fix bad cache name for macos

* Reorganize commands and get rid of the ci command

* Fix dispatch to customized commands for Burn

* Update to last version of tracel-xtask

* Remove unnecessary shell bash in ci workflow

* Update cargo.lock

* Fix format

* Bump tracel-xtask

* Simplify dispatch of base commands using updated macro

* Update to last version of tracel-xtask

* Adapt legacy run_checks script with new xtask commands

* Run xtask in debug for faster compilation time

* Ditch build step in ci and enable coverage for stable linux only

* Freeze tracel-xtask to specific commit rev

* Update cargo.lock

* Update Step 6 of CONTRIBUTING guidelines about run-checks script

* Remove unneeded CI and CD paragraphgs in CONRIBUTING.md

* Change cache version

* Fix typos

* Use centralized actions and workflows

* Update to last version of tracel-xtask

* Update CONTRIBUTING file to mention integration tests

* Add custom build for thumbv6m-none-eabi

* Ignore onnx files for typos check

* Fix action and workflow paths in github workflows

* Fix custom builds on MacOS

* Bump tracel-xtask crate to last version

* Update Cargo.lock

* Update publish workflow to use reusable workflow in tracel repo

* Add --ci flag for build and test commands
2024-08-28 15:57:13 -04:00
Guillaume Lagrange 40d321cc0d
Fix tensor data elem type conversion in book (#2211) 2024-08-28 10:55:10 -04:00
AlteredOxide 0292967000
Feature/codegen gather indices greater than rank 1 (#2199)
* implemented muli-dim index for GatherNode

The `NodeCodegen` impl for `GatherNode` now performs gather in complete
accordance with the ONNX Gather spec.
- a `gather` function was added to the gather.rs file
- `gather()` is now called within the codegen instead of `tensor.select()`
- a test with two test cases have been added
    - test axes 0 and 1
    - both use 2D index tensors

* add gather_onnx to numeric api

Added int and float implementations of gather to the burn-tensor numeric
api:
- named the methods `gather_onnx` to not be confused with the current
  `gather`
- these implementations follow the `Gather` ONNX spec

Updated the gather*.py variants and their onnx outputs

* modified files didn't end up in last commit

* tests passing for onnx gather

The implementation of gather for the ONNX `Gather` spec is tentatively
complete:
- py test models are updated
- onnx_tests are modified and passing: `gather`, `gather_scalar`, and
  `gather_shape`
- node/gather tests are passing

NOTE: The two additional tests in crates/burn-import/src/burn/node/gather.rs that test
the actual functionality of gather are likely to be deleted, since they
are redundant to the tests in
crates/burn-import/onnx-tests/tests/onnx_tests.rs.

* inlined onnx gather within codegen

* rm gather_onnx from public api; rm unnecessary tests

* add comments to gather py models

* some codegen changes; formatting to appease run-checks

- Some necessary changes and improvements to the codegen inlined code
  after translating from public api (removed in previous commit).
- Changed some formatting that run-checks complained about.

* simplify gather codegen; include 1d and 2d onnx tests

Modified the `Gather` codegen per requested changes:
- combined match statements on index
- remove use of `alloc::vec::Vec`
- use map -> collect instead of procedural
- include a 1d index gather onnx test
- remove superflous tests

* delete unused gather.onnx
2024-08-28 07:51:19 -04:00
mepatrick73 795201dcfc
Select kernel from CPA to CubeCL (#2168)
---------

Co-authored-by: louisfd <louisfd94@gmail.com>
2024-08-27 15:17:58 -04:00
syl20bnr a600a7b54e Bump Burn version in the Burn Book 2024-08-27 15:13:40 -04:00
syl20bnr 8e78106680 Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
github-actions[bot] 0c77921aa7
Combined PRs (#2205)
* bump clap from 4.5.15 to 4.5.16

---
updated-dependencies:
- dependency-name: clap
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump quote from 1.0.36 to 1.0.37

Bumps [quote](https://github.com/dtolnay/quote) from 1.0.36 to 1.0.37.
- [Release notes](https://github.com/dtolnay/quote/releases)
- [Commits](https://github.com/dtolnay/quote/compare/1.0.36...1.0.37)

---
updated-dependencies:
- dependency-name: quote
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump serde_json from 1.0.124 to 1.0.127

Bumps [serde_json](https://github.com/serde-rs/json) from 1.0.124 to 1.0.127.
- [Release notes](https://github.com/serde-rs/json/releases)
- [Commits](https://github.com/serde-rs/json/compare/v1.0.124...1.0.127)

---
updated-dependencies:
- dependency-name: serde_json
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump gix-tempfile from 14.0.1 to 14.0.2

Bumps [gix-tempfile](https://github.com/Byron/gitoxide) from 14.0.1 to 14.0.2.
- [Release notes](https://github.com/Byron/gitoxide/releases)
- [Changelog](https://github.com/Byron/gitoxide/blob/main/CHANGELOG.md)
- [Commits](https://github.com/Byron/gitoxide/compare/gix-tempfile-v14.0.1...gix-tempfile-v14.0.2)

---
updated-dependencies:
- dependency-name: gix-tempfile
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump zip from 2.1.6 to 2.2.0

Bumps [zip](https://github.com/zip-rs/zip2) from 2.1.6 to 2.2.0.
- [Release notes](https://github.com/zip-rs/zip2/releases)
- [Changelog](https://github.com/zip-rs/zip2/blob/master/CHANGELOG.md)
- [Commits](https://github.com/zip-rs/zip2/compare/v2.1.6...v2.2.0)

---
updated-dependencies:
- dependency-name: zip
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-08-27 15:03:57 -04:00
nathaniel 4e99ddecdf Fix burn-import version to onnx-ir 2024-08-27 12:58:08 -04:00
nathaniel 0edfc38857 Remove burn-compute from github action 2024-08-27 12:39:01 -04:00
nathaniel 7eb3a7b27a Add cuda in CI publish 2024-08-27 12:36:19 -04:00
nathaniel 9881ca6359 Update cubecl version 2024-08-27 12:05:07 -04:00
Nathaniel Simard 79cd3d5d21
Fix gather unchecked kernel (#2206) 2024-08-26 12:23:02 -04:00
Nathaniel Simard 978ac6c4ec
Chore: Update to newer cubecl version (#2181) 2024-08-25 15:33:16 -04:00
dependabot[bot] 9adf493305
Bump syn from 2.0.74 to 2.0.75 (#2173)
Bumps [syn](https://github.com/dtolnay/syn) from 2.0.74 to 2.0.75.
- [Release notes](https://github.com/dtolnay/syn/releases)
- [Commits](https://github.com/dtolnay/syn/compare/2.0.74...2.0.75)

---
updated-dependencies:
- dependency-name: syn
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-08-25 14:48:09 -04:00
mepatrick73 0beec0e39e
Scatter kernel from cpa to cubecl (#2169) 2024-08-25 13:47:16 -04:00