# Training on a Custom Image Dataset
In this example, a simple CNN model is trained from scratch on the CIFAR-10 dataset, using the `ImageFolderDataset` struct to retrieve images from a folder structure on disk.

Since the original CIFAR-10 source is in a binary format, the data is instead downloaded from a fastai mirror that provides it as `.png` images in the following folder structure:
```
cifar10
├── labels.txt
├── test
│   ├── airplane
│   ├── automobile
│   ├── bird
│   ├── cat
│   ├── deer
│   ├── dog
│   ├── frog
│   ├── horse
│   ├── ship
│   └── truck
└── train
    ├── airplane
    ├── automobile
    ├── bird
    ├── cat
    ├── deer
    ├── dog
    ├── frog
    ├── horse
    ├── ship
    └── truck
```
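Loading the training and test splits is as simple as providing the root path of each folder:

```rust
use burn::data::dataset::vision::ImageFolderDataset;

// Each split's root holds one subfolder per class; `unwrap` panics if the dataset cannot be created.
let train_ds = ImageFolderDataset::new_classification("/path/to/cifar10/train").unwrap();
let test_ds = ImageFolderDataset::new_classification("/path/to/cifar10/test").unwrap();
```

This is how `CIFAR10Loader` loads the data in this example.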
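With this layout, `new_classification` treats each subfolder of the given root as one class. The std-only Rust sketch below mimics that idea so the folder-to-label mapping is easy to see; it is not Burn's implementation. The helper names `scan_image_folder` and `build_fixture`, the `.png`-only filter, and the choice to assign label indices by sorted class name are all assumptions made for illustration.

```rust
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

/// Hypothetical helper (not Burn's API): collects `(image_path, label)` pairs
/// from a `root/<class_name>/<image>.png` layout. Labels are assigned by
/// sorted class name, an assumption made for this sketch.
fn scan_image_folder(root: &Path) -> io::Result<(Vec<String>, Vec<(PathBuf, usize)>)> {
    // Every subdirectory of `root` is treated as one class.
    let mut classes = Vec::new();
    for entry in fs::read_dir(root)? {
        let entry = entry?;
        if entry.file_type()?.is_dir() {
            classes.push(entry.file_name().to_string_lossy().into_owned());
        }
    }
    // Sort so the class -> index mapping does not depend on read_dir order.
    classes.sort();

    let mut items = Vec::new();
    for (label, class) in classes.iter().enumerate() {
        for entry in fs::read_dir(root.join(class))? {
            let path = entry?.path();
            // Keep only regular files with a .png extension (case-insensitive).
            let is_png = path
                .extension()
                .and_then(|ext| ext.to_str())
                .map_or(false, |ext| ext.eq_ignore_ascii_case("png"));
            if path.is_file() && is_png {
                items.push((path, label));
            }
        }
    }
    // Stable sample order, again independent of read_dir order.
    items.sort();
    Ok((classes, items))
}

/// Hypothetical helper: writes a tiny two-class fixture of empty placeholder
/// files (image contents are never decoded in this sketch).
fn build_fixture(root: &Path) -> io::Result<()> {
    let layout: [(&str, &[&str]); 2] = [
        ("airplane", &["0001.png", "0002.png"]),
        ("cat", &["0001.png", "notes.txt"]), // notes.txt should be skipped
    ];
    for (class, files) in layout {
        let dir = root.join(class);
        fs::create_dir_all(&dir)?;
        for file in files {
            fs::write(dir.join(file), b"")?;
        }
    }
    Ok(())
}

fn main() -> io::Result<()> {
    // Hypothetical scratch directory standing in for /path/to/cifar10/train.
    let root = std::env::temp_dir().join(format!("cifar10_sketch_{}", std::process::id()));
    build_fixture(&root)?;

    let (classes, items) = scan_image_folder(&root)?;
    println!("classes: {:?}", classes);
    for (path, label) in &items {
        println!("{} -> {}", path.strip_prefix(&root).unwrap().display(), label);
    }

    fs::remove_dir_all(&root)?;
    Ok(())
}
```

Sorting both the class names and the file paths keeps label indices and sample order independent of the order in which the filesystem happens to list directory entries.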
## Example Usage
The CNN model and training recipe used in this example are fairly simple since the objective is to demonstrate how to load a custom image classification dataset from disk. Nonetheless, it still achieves 70-80% accuracy on the test set after just 30 epochs.
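For a classifier like this one, test-set accuracy is usually the fraction of images whose highest-scoring class (the argmax of the model's output logits) matches the true label. The std-only Rust sketch below shows that arithmetic on toy numbers; `argmax` and `accuracy` are hypothetical helpers written for illustration, not Burn's metric API, and the sketch assumes one logit vector per image and integer class labels.

```rust
/// Hypothetical helper: index of the largest logit (ties resolve to the first maximum).
fn argmax(logits: &[f32]) -> usize {
    let mut best = 0;
    for (i, &value) in logits.iter().enumerate() {
        if value > logits[best] {
            best = i;
        }
    }
    best
}

/// Hypothetical helper: fraction of samples whose argmax prediction equals
/// the target label. Returns 0.0 for an empty batch.
fn accuracy(logits: &[Vec<f32>], targets: &[usize]) -> f32 {
    assert_eq!(logits.len(), targets.len(), "need one target per prediction");
    if targets.is_empty() {
        return 0.0;
    }
    let correct = logits
        .iter()
        .zip(targets)
        .filter(|&(row, &target)| argmax(row) == target)
        .count();
    correct as f32 / targets.len() as f32
}

fn main() {
    // Toy logits for four samples over three classes (not real CIFAR-10 outputs).
    let logits: Vec<Vec<f32>> = vec![
        vec![2.0, 0.1, -1.0], // predicts class 0
        vec![0.2, 1.5, 0.3],  // predicts class 1
        vec![0.0, 0.4, 0.9],  // predicts class 2
        vec![1.2, 0.8, 0.1],  // predicts class 0
    ];
    let targets: [usize; 4] = [0, 1, 1, 0];
    // Three of the four predictions match, so this prints "accuracy = 0.75".
    println!("accuracy = {:.2}", accuracy(&logits, &targets));
}
```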
Run it with the Torch GPU backend:

```sh
export TORCH_CUDA_VERSION=cu121
cargo run --example custom-image-dataset --release --features tch-gpu
```

Run it with our WGPU backend:

```sh
cargo run --example custom-image-dataset --release --features wgpu
```