mirror of https://github.com/tracel-ai/burn.git
refactor, feat: clean Cargo.toml files, upgrade tch to 0.10 (#131)
* Clean Cargo.toml files, upgrade tch to 0.10 * Add pull_request hook to test.yml workflow
This commit is contained in:
parent
1ec35a9e1b
commit
85f98b9d54
|
@ -25,5 +25,3 @@ jobs:
|
||||||
run: ./ci/publish.sh ${{ inputs.crate }}
|
run: ./ci/publish.sh ${{ inputs.crate }}
|
||||||
env:
|
env:
|
||||||
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
|
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ name: publish
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- 'v*'
|
- 'v*'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
@ -61,7 +61,7 @@ jobs:
|
||||||
|
|
||||||
publish-burn:
|
publish-burn:
|
||||||
uses: burn-rs/burn/.github/workflows/publish-template.yml@main
|
uses: burn-rs/burn/.github/workflows/publish-template.yml@main
|
||||||
needs:
|
needs:
|
||||||
- publish-burn-derive
|
- publish-burn-derive
|
||||||
- publish-burn-dataset
|
- publish-burn-dataset
|
||||||
- publish-burn-tensor
|
- publish-burn-tensor
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
name: test
|
name: test
|
||||||
|
|
||||||
on: [push]
|
on: [push, pull_request]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test-burn-dataset:
|
test-burn-dataset:
|
||||||
|
|
|
@ -4,7 +4,7 @@ version = "0.3.0"
|
||||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||||
description = "Autodiff backend for burn"
|
description = "Autodiff backend for burn"
|
||||||
repository = "https://github.com/burn-rs/burn/tree/main/burn-autodiff"
|
repository = "https://github.com/burn-rs/burn/tree/main/burn-autodiff"
|
||||||
readme="README.md"
|
readme = "README.md"
|
||||||
keywords = ["deep-learning", "machine-learning", "data"]
|
keywords = ["deep-learning", "machine-learning", "data"]
|
||||||
categories = ["science"]
|
categories = ["science"]
|
||||||
license = "MIT/Apache-2.0"
|
license = "MIT/Apache-2.0"
|
||||||
|
@ -17,5 +17,5 @@ export_tests = ["burn-tensor-testgen"]
|
||||||
[dependencies]
|
[dependencies]
|
||||||
burn-tensor-testgen = { version = "0.3.0", path = "../burn-tensor-testgen", optional = true }
|
burn-tensor-testgen = { version = "0.3.0", path = "../burn-tensor-testgen", optional = true }
|
||||||
burn-tensor = { version = "0.3.0", path = "../burn-tensor" }
|
burn-tensor = { version = "0.3.0", path = "../burn-tensor" }
|
||||||
derive-new = "0.5"
|
derive-new = "0.5.9"
|
||||||
nanoid = "0.4"
|
nanoid = "0.4.0"
|
||||||
|
|
|
@ -8,7 +8,7 @@ This library provides an easy to use dataset API with many manipulations
|
||||||
to easily create your ML data pipeline.
|
to easily create your ML data pipeline.
|
||||||
"""
|
"""
|
||||||
repository = "https://github.com/burn-rs/burn/tree/main/burn-dataset"
|
repository = "https://github.com/burn-rs/burn/tree/main/burn-dataset"
|
||||||
readme="README.md"
|
readme = "README.md"
|
||||||
keywords = ["deep-learning", "machine-learning", "data"]
|
keywords = ["deep-learning", "machine-learning", "data"]
|
||||||
categories = ["science"]
|
categories = ["science"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
@ -19,11 +19,11 @@ default = ["fake"]
|
||||||
fake = ["dep:fake"]
|
fake = ["dep:fake"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
dirs = "4.0"
|
dirs = "4.0.0"
|
||||||
rand = "0.8.4"
|
rand = "0.8.5"
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1.0.151", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1.0.91"
|
||||||
fake = { version = "2.5", optional = true }
|
fake = { version = "2.5.0", optional = true }
|
||||||
thiserror = "1.0"
|
thiserror = "1.0.38"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
|
|
@ -7,7 +7,7 @@ description = """
|
||||||
Burn derive crate.
|
Burn derive crate.
|
||||||
"""
|
"""
|
||||||
repository = "https://github.com/burn-rs/burn/tree/main/burn-dataset"
|
repository = "https://github.com/burn-rs/burn/tree/main/burn-dataset"
|
||||||
readme="README.md"
|
readme = "README.md"
|
||||||
keywords = []
|
keywords = []
|
||||||
categories = ["science"]
|
categories = ["science"]
|
||||||
license = "MIT/Apache-2.0"
|
license = "MIT/Apache-2.0"
|
||||||
|
@ -17,6 +17,6 @@ edition = "2021"
|
||||||
proc-macro = true
|
proc-macro = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
syn = "1.0"
|
syn = "1.0.107"
|
||||||
quote = "1.0"
|
quote = "1.0.23"
|
||||||
proc-macro2 = "1.0"
|
proc-macro2 = "1.0.49"
|
||||||
|
|
|
@ -5,7 +5,7 @@ authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||||
|
|
||||||
description = "NdArray backend for burn"
|
description = "NdArray backend for burn"
|
||||||
repository = "https://github.com/burn-rs/burn/tree/main/burn-ndarray"
|
repository = "https://github.com/burn-rs/burn/tree/main/burn-ndarray"
|
||||||
readme="README.md"
|
readme = "README.md"
|
||||||
keywords = ["deep-learning", "machine-learning", "data"]
|
keywords = ["deep-learning", "machine-learning", "data"]
|
||||||
categories = ["science"]
|
categories = ["science"]
|
||||||
license = "MIT/Apache-2.0"
|
license = "MIT/Apache-2.0"
|
||||||
|
@ -15,19 +15,27 @@ edition = "2021"
|
||||||
default = []
|
default = []
|
||||||
blas-netlib = ["ndarray/blas", "blas-src/netlib"]
|
blas-netlib = ["ndarray/blas", "blas-src/netlib"]
|
||||||
blas-openblas = ["ndarray/blas", "blas-src/openblas", "openblas-src"]
|
blas-openblas = ["ndarray/blas", "blas-src/openblas", "openblas-src"]
|
||||||
blas-openblas-system = ["ndarray/blas", "blas-src/openblas", "openblas-src/system"]
|
blas-openblas-system = [
|
||||||
|
"ndarray/blas",
|
||||||
|
"blas-src/openblas",
|
||||||
|
"openblas-src/system",
|
||||||
|
]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
burn-tensor = { version = "0.3.0", path = "../burn-tensor" }
|
burn-tensor = { version = "0.3.0", path = "../burn-tensor" }
|
||||||
blas-src = { version = "0.8.0", default-features = false, optional = true }
|
blas-src = { version = "0.8.0", default-features = false, optional = true }
|
||||||
openblas-src = { version = "0.10", optional = true }
|
openblas-src = { version = "0.10.5", optional = true }
|
||||||
|
|
||||||
ndarray = "0.15"
|
ndarray = "0.15.6"
|
||||||
libm = "0.2"
|
libm = "0.2.6"
|
||||||
derive-new = "0.5"
|
derive-new = "0.5.9"
|
||||||
rand = "0.8"
|
rand = "0.8.5"
|
||||||
num-traits = "0.2"
|
num-traits = "0.2.15"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
burn-tensor = { version = "0.3.0", path = "../burn-tensor", features = ["export_tests"] }
|
burn-tensor = { version = "0.3.0", path = "../burn-tensor", features = [
|
||||||
burn-autodiff = { version = "0.3.0", path = "../burn-autodiff", features = ["export_tests"] }
|
"export_tests",
|
||||||
|
] }
|
||||||
|
burn-autodiff = { version = "0.3.0", path = "../burn-autodiff", features = [
|
||||||
|
"export_tests",
|
||||||
|
] }
|
||||||
|
|
|
@ -5,7 +5,7 @@ authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||||
|
|
||||||
description = "Tch backend for burn"
|
description = "Tch backend for burn"
|
||||||
repository = "https://github.com/burn-rs/burn/tree/main/burn-tch"
|
repository = "https://github.com/burn-rs/burn/tree/main/burn-tch"
|
||||||
readme="README.md"
|
readme = "README.md"
|
||||||
keywords = ["deep-learning", "machine-learning", "data"]
|
keywords = ["deep-learning", "machine-learning", "data"]
|
||||||
categories = ["science"]
|
categories = ["science"]
|
||||||
license = "MIT/Apache-2.0"
|
license = "MIT/Apache-2.0"
|
||||||
|
@ -16,11 +16,17 @@ doc = ["tch/doc-only"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
burn-tensor = { version = "0.3.0", path = "../burn-tensor", default-features = false }
|
burn-tensor = { version = "0.3.0", path = "../burn-tensor", default-features = false }
|
||||||
rand = "0.8"
|
rand = "0.8.5"
|
||||||
tch = { version = "0.8" }
|
tch = { version = "0.10.1" }
|
||||||
lazy_static = "1.4"
|
lazy_static = "1.4.0"
|
||||||
half = { version = "1.6", features = ["num-traits"] } # needs to be 1.6 to work with tch
|
half = { version = "1.6.0", features = [
|
||||||
|
"num-traits",
|
||||||
|
] } # needs to be 1.6 to work with tch
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
burn-tensor = { version = "0.3.0", path = "../burn-tensor", default-features = false, features = ["export_tests"] }
|
burn-tensor = { version = "0.3.0", path = "../burn-tensor", default-features = false, features = [
|
||||||
burn-autodiff = { version = "0.3.0", path = "../burn-autodiff", default-features = false, features = ["export_tests"] }
|
"export_tests",
|
||||||
|
] }
|
||||||
|
burn-autodiff = { version = "0.3.0", path = "../burn-autodiff", default-features = false, features = [
|
||||||
|
"export_tests",
|
||||||
|
] }
|
||||||
|
|
|
@ -15,10 +15,12 @@ use burn_tensor::backend::Backend;
|
||||||
/// let device_gpu_1 = TchDevice::Cuda(0); // First GPU
|
/// let device_gpu_1 = TchDevice::Cuda(0); // First GPU
|
||||||
/// let device_gpu_2 = TchDevice::Cuda(1); // Second GPU
|
/// let device_gpu_2 = TchDevice::Cuda(1); // Second GPU
|
||||||
/// let device_cpu = TchDevice::Cpu; // CPU
|
/// let device_cpu = TchDevice::Cpu; // CPU
|
||||||
|
/// let device_mps = TchDevice::Mps; // Metal Performance Shaders
|
||||||
/// ```
|
/// ```
|
||||||
pub enum TchDevice {
|
pub enum TchDevice {
|
||||||
Cpu,
|
Cpu,
|
||||||
Cuda(usize),
|
Cuda(usize),
|
||||||
|
Mps,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<TchDevice> for tch::Device {
|
impl From<TchDevice> for tch::Device {
|
||||||
|
@ -26,6 +28,17 @@ impl From<TchDevice> for tch::Device {
|
||||||
match device {
|
match device {
|
||||||
TchDevice::Cpu => tch::Device::Cpu,
|
TchDevice::Cpu => tch::Device::Cpu,
|
||||||
TchDevice::Cuda(num) => tch::Device::Cuda(num),
|
TchDevice::Cuda(num) => tch::Device::Cuda(num),
|
||||||
|
TchDevice::Mps => tch::Device::Mps,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<tch::Device> for TchDevice {
|
||||||
|
fn from(device: tch::Device) -> Self {
|
||||||
|
match device {
|
||||||
|
tch::Device::Cpu => TchDevice::Cpu,
|
||||||
|
tch::Device::Cuda(num) => TchDevice::Cuda(num),
|
||||||
|
tch::Device::Mps => TchDevice::Mps,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,22 +4,14 @@ use std::ops::{Add, Div, Mul, Range, Sub};
|
||||||
|
|
||||||
impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
||||||
fn from_data<const D: usize>(data: Data<E, D>, device: TchDevice) -> TchTensor<E, D> {
|
fn from_data<const D: usize>(data: Data<E, D>, device: TchDevice) -> TchTensor<E, D> {
|
||||||
let device = match device {
|
TchTensor::from_data(data, device.into())
|
||||||
TchDevice::Cpu => tch::Device::Cpu,
|
|
||||||
TchDevice::Cuda(num) => tch::Device::Cuda(num),
|
|
||||||
};
|
|
||||||
TchTensor::from_data(data, device)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_data_bool<const D: usize>(
|
fn from_data_bool<const D: usize>(
|
||||||
data: Data<bool, D>,
|
data: Data<bool, D>,
|
||||||
device: TchDevice,
|
device: TchDevice,
|
||||||
) -> TchTensor<bool, D> {
|
) -> TchTensor<bool, D> {
|
||||||
let device = match device {
|
TchTensor::from_data(data, device.into())
|
||||||
TchDevice::Cpu => tch::Device::Cpu,
|
|
||||||
TchDevice::Cuda(num) => tch::Device::Cuda(num),
|
|
||||||
};
|
|
||||||
TchTensor::from_data(data, device)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn random<const D: usize>(
|
fn random<const D: usize>(
|
||||||
|
@ -47,7 +39,7 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
||||||
}
|
}
|
||||||
Distribution::Normal(mean, std) => {
|
Distribution::Normal(mean, std) => {
|
||||||
let mut tensor = TchTensor::<E, D>::empty(shape, device);
|
let mut tensor = TchTensor::<E, D>::empty(shape, device);
|
||||||
tensor.tensor = tensor.tensor.normal(mean, std);
|
tensor.tensor = tensor.tensor.normal_(mean, std);
|
||||||
tensor
|
tensor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -107,13 +99,9 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
||||||
tensor: &TchTensor<bool, D>,
|
tensor: &TchTensor<bool, D>,
|
||||||
device: TchDevice,
|
device: TchDevice,
|
||||||
) -> TchTensor<bool, D> {
|
) -> TchTensor<bool, D> {
|
||||||
let device = match device {
|
|
||||||
TchDevice::Cpu => tch::Device::Cpu,
|
|
||||||
TchDevice::Cuda(num) => tch::Device::Cuda(num),
|
|
||||||
};
|
|
||||||
TchTensor {
|
TchTensor {
|
||||||
kind: tensor.kind,
|
kind: tensor.kind,
|
||||||
tensor: tensor.tensor.to(device),
|
tensor: tensor.tensor.to(device.into()),
|
||||||
shape: tensor.shape,
|
shape: tensor.shape,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -134,20 +122,13 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn device<const D: usize>(tensor: &TchTensor<E, D>) -> TchDevice {
|
fn device<const D: usize>(tensor: &TchTensor<E, D>) -> TchDevice {
|
||||||
match tensor.tensor.device() {
|
tensor.tensor.device().into()
|
||||||
tch::Device::Cpu => TchDevice::Cpu,
|
|
||||||
tch::Device::Cuda(num) => TchDevice::Cuda(num),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_device<const D: usize>(tensor: &TchTensor<E, D>, device: TchDevice) -> TchTensor<E, D> {
|
fn to_device<const D: usize>(tensor: &TchTensor<E, D>, device: TchDevice) -> TchTensor<E, D> {
|
||||||
let device = match device {
|
|
||||||
TchDevice::Cpu => tch::Device::Cpu,
|
|
||||||
TchDevice::Cuda(num) => tch::Device::Cuda(num),
|
|
||||||
};
|
|
||||||
TchTensor {
|
TchTensor {
|
||||||
kind: tensor.kind,
|
kind: tensor.kind,
|
||||||
tensor: tensor.tensor.to(device),
|
tensor: tensor.tensor.to(device.into()),
|
||||||
shape: tensor.shape,
|
shape: tensor.shape,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -419,16 +400,18 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mean_dim<const D: usize>(tensor: &TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
|
fn mean_dim<const D: usize>(tensor: &TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
|
||||||
let tensor = tensor
|
let tensor =
|
||||||
.tensor
|
tensor
|
||||||
.mean_dim(&[dim as i64], true, tensor.kind.kind());
|
.tensor
|
||||||
|
.mean_dim(Some([dim as i64].as_slice()), true, tensor.kind.kind());
|
||||||
to_tensor(tensor)
|
to_tensor(tensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sum_dim<const D: usize>(tensor: &TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
|
fn sum_dim<const D: usize>(tensor: &TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
|
||||||
let tensor = tensor
|
let tensor =
|
||||||
.tensor
|
tensor
|
||||||
.sum_dim_intlist(&[dim as i64], true, tensor.kind.kind());
|
.tensor
|
||||||
|
.sum_dim_intlist(Some([dim as i64].as_slice()), true, tensor.kind.kind());
|
||||||
to_tensor(tensor)
|
to_tensor(tensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||||
|
|
||||||
description = "Burn tensor test gen crate."
|
description = "Burn tensor test gen crate."
|
||||||
repository = "https://github.com/burn-rs/burn/tree/main/burn-tensor-testgen"
|
repository = "https://github.com/burn-rs/burn/tree/main/burn-tensor-testgen"
|
||||||
readme="README.md"
|
readme = "README.md"
|
||||||
license = "MIT/Apache-2.0"
|
license = "MIT/Apache-2.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
|
@ -13,6 +13,6 @@ edition = "2021"
|
||||||
proc-macro = true
|
proc-macro = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
syn = "1.0"
|
syn = "1.0.107"
|
||||||
quote = "1.0"
|
quote = "1.0.23"
|
||||||
proc-macro2 = "1.0"
|
proc-macro2 = "1.0.49"
|
||||||
|
|
|
@ -8,7 +8,7 @@ This library provides multiple tensor implementations hidden behind
|
||||||
an easy to use API that supports reverse mode automatic differentiation.
|
an easy to use API that supports reverse mode automatic differentiation.
|
||||||
"""
|
"""
|
||||||
repository = "https://github.com/burn-rs/burn/tree/main/burn-tensor"
|
repository = "https://github.com/burn-rs/burn/tree/main/burn-tensor"
|
||||||
readme="README.md"
|
readme = "README.md"
|
||||||
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
|
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
|
||||||
categories = ["science"]
|
categories = ["science"]
|
||||||
license = "MIT/Apache-2.0"
|
license = "MIT/Apache-2.0"
|
||||||
|
@ -21,14 +21,16 @@ experimental-named-tensor = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
burn-tensor-testgen = { version = "0.3.0", path = "../burn-tensor-testgen", optional = true }
|
burn-tensor-testgen = { version = "0.3.0", path = "../burn-tensor-testgen", optional = true }
|
||||||
num-traits = "0.2"
|
num-traits = "0.2.15"
|
||||||
derive-new = "0.5"
|
derive-new = "0.5.9"
|
||||||
rand = "0.8"
|
rand = "0.8.5"
|
||||||
statrs = "0.16"
|
statrs = "0.16.0"
|
||||||
half = { version = "1.6", features = ["num-traits"] } # needs to be 1.6 to work with tch
|
half = { version = "1.6.0", features = [
|
||||||
|
"num-traits",
|
||||||
|
] } # needs to be 1.6 to work with tch
|
||||||
|
|
||||||
# Autodiff
|
# Autodiff
|
||||||
nanoid = "0.4"
|
nanoid = "0.4.0"
|
||||||
|
|
||||||
# Serialization
|
# Serialization
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0.151", features = ["derive"] }
|
||||||
|
|
|
@ -4,7 +4,7 @@ version = "0.3.0"
|
||||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||||
description = "BURN: Burn Unstoppable Rusty Neurons"
|
description = "BURN: Burn Unstoppable Rusty Neurons"
|
||||||
repository = "https://github.com/burn-rs/burn"
|
repository = "https://github.com/burn-rs/burn"
|
||||||
readme="README.md"
|
readme = "README.md"
|
||||||
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
|
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
|
||||||
categories = ["science"]
|
categories = ["science"]
|
||||||
license = "MIT/Apache-2.0"
|
license = "MIT/Apache-2.0"
|
||||||
|
@ -20,30 +20,32 @@ burn-autodiff = { version = "0.3.0", path = "../burn-autodiff" }
|
||||||
burn-dataset = { version = "0.3.0", path = "../burn-dataset", default-features = false }
|
burn-dataset = { version = "0.3.0", path = "../burn-dataset", default-features = false }
|
||||||
burn-derive = { version = "0.3.0", path = "../burn-derive" }
|
burn-derive = { version = "0.3.0", path = "../burn-derive" }
|
||||||
|
|
||||||
thiserror = "1.0"
|
thiserror = "1.0.38"
|
||||||
num-traits = "0.2"
|
num-traits = "0.2.15"
|
||||||
derive-new = "0.5"
|
derive-new = "0.5.9"
|
||||||
rand = "0.8"
|
rand = "0.8.5"
|
||||||
|
|
||||||
# Metrics
|
# Metrics
|
||||||
nvml-wrapper = "0.8"
|
nvml-wrapper = "0.8.0"
|
||||||
textplots = "0.8"
|
textplots = "0.8.0"
|
||||||
rgb = "0.8"
|
rgb = "0.8.34"
|
||||||
terminal_size = "0.2"
|
terminal_size = "0.2.3"
|
||||||
|
|
||||||
# Console
|
# Console
|
||||||
indicatif = "0.17"
|
indicatif = "0.17.2"
|
||||||
log4rs = "1.2"
|
log4rs = "1.2.0"
|
||||||
log = "0.4"
|
log = "0.4.17"
|
||||||
|
|
||||||
# Serialize Deserialize
|
# Serialize Deserialize
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0.151", features = ["derive"] }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0.91"
|
||||||
flate2 = "1.0"
|
flate2 = "1.0.25"
|
||||||
|
|
||||||
# Parameter & Optimization
|
# Parameter & Optimization
|
||||||
nanoid = "0.4"
|
nanoid = "0.4.0"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
burn-dataset = { version = "0.3.0", path = "../burn-dataset", features = ["fake"] }
|
burn-dataset = { version = "0.3.0", path = "../burn-dataset", features = [
|
||||||
|
"fake",
|
||||||
|
] }
|
||||||
burn-ndarray = { version = "0.3.0", path = "../burn-ndarray" }
|
burn-ndarray = { version = "0.3.0", path = "../burn-ndarray" }
|
||||||
|
|
|
@ -21,4 +21,4 @@ burn-tch = { path = "../../burn-tch", optional = true }
|
||||||
burn-ndarray = { path = "../../burn-ndarray", optional = true }
|
burn-ndarray = { path = "../../burn-ndarray", optional = true }
|
||||||
|
|
||||||
# Serialization
|
# Serialization
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0.151", features = ["derive"] }
|
||||||
|
|
|
@ -21,7 +21,10 @@ mod tch_gpu {
|
||||||
use mnist::training;
|
use mnist::training;
|
||||||
|
|
||||||
pub fn run() {
|
pub fn run() {
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
let device = TchDevice::Cuda(0);
|
let device = TchDevice::Cuda(0);
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
let device = TchDevice::Mps;
|
||||||
training::run::<ADBackendDecorator<TchBackend<burn::tensor::f16>>>(device);
|
training::run::<ADBackendDecorator<TchBackend<burn::tensor::f16>>>(device);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,4 +12,4 @@ burn-autodiff = { path = "../../burn-autodiff" }
|
||||||
burn-ndarray = { path = "../../burn-ndarray" }
|
burn-ndarray = { path = "../../burn-ndarray" }
|
||||||
|
|
||||||
# Serialization
|
# Serialization
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0.151", features = ["derive"] }
|
||||||
|
|
|
@ -16,8 +16,11 @@ burn-autodiff = { path = "../../burn-autodiff" }
|
||||||
burn-tch = { path = "../../burn-tch" }
|
burn-tch = { path = "../../burn-tch" }
|
||||||
|
|
||||||
# Tokenizer
|
# Tokenizer
|
||||||
tokenizers = { version = "0.13", default-features = false, features = ["onig", "http"] }
|
tokenizers = { version = "0.13.2", default-features = false, features = [
|
||||||
|
"onig",
|
||||||
|
"http",
|
||||||
|
] }
|
||||||
|
|
||||||
# Utils
|
# Utils
|
||||||
derive-new = "0.5"
|
derive-new = "0.5.9"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0.151", features = ["derive"] }
|
||||||
|
|
|
@ -13,7 +13,11 @@ fn main() {
|
||||||
);
|
);
|
||||||
|
|
||||||
text_classification::training::train::<Backend, AgNewsDataset>(
|
text_classification::training::train::<Backend, AgNewsDataset>(
|
||||||
burn_tch::TchDevice::Cuda(0),
|
if cfg!(target_os = "macos") {
|
||||||
|
burn_tch::TchDevice::Mps
|
||||||
|
} else {
|
||||||
|
burn_tch::TchDevice::Cuda(0)
|
||||||
|
},
|
||||||
AgNewsDataset::train(),
|
AgNewsDataset::train(),
|
||||||
AgNewsDataset::test(),
|
AgNewsDataset::test(),
|
||||||
config,
|
config,
|
||||||
|
|
|
@ -13,7 +13,11 @@ fn main() {
|
||||||
);
|
);
|
||||||
|
|
||||||
text_classification::training::train::<Backend, DbPediaDataset>(
|
text_classification::training::train::<Backend, DbPediaDataset>(
|
||||||
burn_tch::TchDevice::Cuda(0),
|
if cfg!(target_os = "macos") {
|
||||||
|
burn_tch::TchDevice::Mps
|
||||||
|
} else {
|
||||||
|
burn_tch::TchDevice::Cuda(0)
|
||||||
|
},
|
||||||
DbPediaDataset::train(),
|
DbPediaDataset::train(),
|
||||||
DbPediaDataset::test(),
|
DbPediaDataset::test(),
|
||||||
config,
|
config,
|
||||||
|
|
|
@ -14,11 +14,14 @@ default = []
|
||||||
burn = { path = "../../burn" }
|
burn = { path = "../../burn" }
|
||||||
burn-autodiff = { path = "../../burn-autodiff" }
|
burn-autodiff = { path = "../../burn-autodiff" }
|
||||||
burn-tch = { path = "../../burn-tch" }
|
burn-tch = { path = "../../burn-tch" }
|
||||||
log = "0.4"
|
log = "0.4.17"
|
||||||
|
|
||||||
# Tokenizer
|
# Tokenizer
|
||||||
tokenizers = { version = "0.13", default-features = false, features = ["onig", "http"] }
|
tokenizers = { version = "0.13.2", default-features = false, features = [
|
||||||
|
"onig",
|
||||||
|
"http",
|
||||||
|
] }
|
||||||
|
|
||||||
# Utils
|
# Utils
|
||||||
derive-new = "0.5"
|
derive-new = "0.5.9"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0.151", features = ["derive"] }
|
||||||
|
|
|
@ -13,7 +13,11 @@ fn main() {
|
||||||
);
|
);
|
||||||
|
|
||||||
text_generation::training::train::<Backend, DbPediaDataset>(
|
text_generation::training::train::<Backend, DbPediaDataset>(
|
||||||
burn_tch::TchDevice::Cuda(0),
|
if cfg!(target_os = "macos") {
|
||||||
|
burn_tch::TchDevice::Mps
|
||||||
|
} else {
|
||||||
|
burn_tch::TchDevice::Cuda(0)
|
||||||
|
},
|
||||||
DbPediaDataset::train(),
|
DbPediaDataset::train(),
|
||||||
DbPediaDataset::test(),
|
DbPediaDataset::test(),
|
||||||
config,
|
config,
|
||||||
|
|
Loading…
Reference in New Issue