Add cuda gpu example + doc (#91)

This commit is contained in:
Nathaniel Simard 2022-11-09 21:32:51 -05:00 committed by GitHub
parent 947ed00301
commit 3da122db09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 8 deletions

View File

@ -44,14 +44,6 @@ Also, this may be a good idea to checkout the main [components](#components) to
For now there is only one example, but more to come 💪..
The `mnist` example can be run like so:
```console
$ git clone https://github.com/burn-rs/burn.git
$ cd burn
$ cargo run --example mnist
```
#### MNIST
The [MNIST](https://github.com/burn-rs/burn/blob/main/examples/mnist) example is not just of small script that shows you how to train a basic model, but it's a quick one showing you how to:
@ -60,6 +52,17 @@ The [MNIST](https://github.com/burn-rs/burn/blob/main/examples/mnist) example is
* Create the data pipeline from a raw dataset to a batched multi-threaded fast DataLoader.
* Configure a [learner](#learner) to display and log metrics as well as to keep training checkpoints.
The example can be run like so:
```console
$ git clone https://github.com/burn-rs/burn.git
$ cd burn
$ export TORCH_CUDA_VERSION=cu113 # Set the cuda version
$ # Use the --release flag to really speed up training.
$ cargo run --example mnist --release # CPU NdArray Backend
$ cargo run --example mnist_cuda_gpu --release # GPU Tch Backend
```
### Components
Knowing the main components will be of great help when starting playing with `burn`.

View File

@ -0,0 +1,9 @@
use mnist::training;
fn main() {
use burn::tensor::backend::{TchADBackend, TchDevice};
let device = TchDevice::Cuda(0);
training::run::<TchADBackend<burn::tensor::f16>>(device);
println!("Done.");
}