refactor, feat: clean Cargo.toml files, upgrade tch to 0.10 (#131)

* Clean Cargo.toml files, upgrade tch to 0.10

* Add pull_request hook to test.yml workflow
This commit is contained in:
Visual 2022-12-25 17:36:23 +02:00 committed by GitHub
parent 1ec35a9e1b
commit 85f98b9d54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 141 additions and 108 deletions

View File

@ -25,5 +25,3 @@ jobs:
run: ./ci/publish.sh ${{ inputs.crate }}
env:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

View File

@ -1,6 +1,6 @@
name: test
on: [push]
on: [push, pull_request]
jobs:
test-burn-dataset:

View File

@ -4,7 +4,7 @@ version = "0.3.0"
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
description = "Autodiff backend for burn"
repository = "https://github.com/burn-rs/burn/tree/main/burn-autodiff"
readme="README.md"
readme = "README.md"
keywords = ["deep-learning", "machine-learning", "data"]
categories = ["science"]
license = "MIT/Apache-2.0"
@ -17,5 +17,5 @@ export_tests = ["burn-tensor-testgen"]
[dependencies]
burn-tensor-testgen = { version = "0.3.0", path = "../burn-tensor-testgen", optional = true }
burn-tensor = { version = "0.3.0", path = "../burn-tensor" }
derive-new = "0.5"
nanoid = "0.4"
derive-new = "0.5.9"
nanoid = "0.4.0"

View File

@ -8,7 +8,7 @@ This library provides an easy to use dataset API with many manipulations
to easily create your ML data pipeline.
"""
repository = "https://github.com/burn-rs/burn/tree/main/burn-dataset"
readme="README.md"
readme = "README.md"
keywords = ["deep-learning", "machine-learning", "data"]
categories = ["science"]
license = "MIT"
@ -19,11 +19,11 @@ default = ["fake"]
fake = ["dep:fake"]
[dependencies]
dirs = "4.0"
rand = "0.8.4"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
fake = { version = "2.5", optional = true }
thiserror = "1.0"
dirs = "4.0.0"
rand = "0.8.5"
serde = { version = "1.0.151", features = ["derive"] }
serde_json = "1.0.91"
fake = { version = "2.5.0", optional = true }
thiserror = "1.0.38"
[dev-dependencies]

View File

@ -7,7 +7,7 @@ description = """
Burn derive crate.
"""
repository = "https://github.com/burn-rs/burn/tree/main/burn-dataset"
readme="README.md"
readme = "README.md"
keywords = []
categories = ["science"]
license = "MIT/Apache-2.0"
@ -17,6 +17,6 @@ edition = "2021"
proc-macro = true
[dependencies]
syn = "1.0"
quote = "1.0"
proc-macro2 = "1.0"
syn = "1.0.107"
quote = "1.0.23"
proc-macro2 = "1.0.49"

View File

@ -5,7 +5,7 @@ authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
description = "NdArray backend for burn"
repository = "https://github.com/burn-rs/burn/tree/main/burn-ndarray"
readme="README.md"
readme = "README.md"
keywords = ["deep-learning", "machine-learning", "data"]
categories = ["science"]
license = "MIT/Apache-2.0"
@ -15,19 +15,27 @@ edition = "2021"
default = []
blas-netlib = ["ndarray/blas", "blas-src/netlib"]
blas-openblas = ["ndarray/blas", "blas-src/openblas", "openblas-src"]
blas-openblas-system = ["ndarray/blas", "blas-src/openblas", "openblas-src/system"]
blas-openblas-system = [
"ndarray/blas",
"blas-src/openblas",
"openblas-src/system",
]
[dependencies]
burn-tensor = { version = "0.3.0", path = "../burn-tensor" }
blas-src = { version = "0.8.0", default-features = false, optional = true }
openblas-src = { version = "0.10", optional = true }
openblas-src = { version = "0.10.5", optional = true }
ndarray = "0.15"
libm = "0.2"
derive-new = "0.5"
rand = "0.8"
num-traits = "0.2"
ndarray = "0.15.6"
libm = "0.2.6"
derive-new = "0.5.9"
rand = "0.8.5"
num-traits = "0.2.15"
[dev-dependencies]
burn-tensor = { version = "0.3.0", path = "../burn-tensor", features = ["export_tests"] }
burn-autodiff = { version = "0.3.0", path = "../burn-autodiff", features = ["export_tests"] }
burn-tensor = { version = "0.3.0", path = "../burn-tensor", features = [
"export_tests",
] }
burn-autodiff = { version = "0.3.0", path = "../burn-autodiff", features = [
"export_tests",
] }

View File

@ -5,7 +5,7 @@ authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
description = "Tch backend for burn"
repository = "https://github.com/burn-rs/burn/tree/main/burn-tch"
readme="README.md"
readme = "README.md"
keywords = ["deep-learning", "machine-learning", "data"]
categories = ["science"]
license = "MIT/Apache-2.0"
@ -16,11 +16,17 @@ doc = ["tch/doc-only"]
[dependencies]
burn-tensor = { version = "0.3.0", path = "../burn-tensor", default-features = false }
rand = "0.8"
tch = { version = "0.8" }
lazy_static = "1.4"
half = { version = "1.6", features = ["num-traits"] } # needs to be 1.6 to work with tch
rand = "0.8.5"
tch = { version = "0.10.1" }
lazy_static = "1.4.0"
half = { version = "1.6.0", features = [
"num-traits",
] } # needs to be 1.6 to work with tch
[dev-dependencies]
burn-tensor = { version = "0.3.0", path = "../burn-tensor", default-features = false, features = ["export_tests"] }
burn-autodiff = { version = "0.3.0", path = "../burn-autodiff", default-features = false, features = ["export_tests"] }
burn-tensor = { version = "0.3.0", path = "../burn-tensor", default-features = false, features = [
"export_tests",
] }
burn-autodiff = { version = "0.3.0", path = "../burn-autodiff", default-features = false, features = [
"export_tests",
] }

View File

@ -15,10 +15,12 @@ use burn_tensor::backend::Backend;
/// let device_gpu_1 = TchDevice::Cuda(0); // First GPU
/// let device_gpu_2 = TchDevice::Cuda(1); // Second GPU
/// let device_cpu = TchDevice::Cpu; // CPU
/// let device_mps = TchDevice::Mps; // Metal Performance Shaders
/// ```
pub enum TchDevice {
Cpu,
Cuda(usize),
Mps,
}
impl From<TchDevice> for tch::Device {
@ -26,6 +28,17 @@ impl From<TchDevice> for tch::Device {
match device {
TchDevice::Cpu => tch::Device::Cpu,
TchDevice::Cuda(num) => tch::Device::Cuda(num),
TchDevice::Mps => tch::Device::Mps,
}
}
}
impl From<tch::Device> for TchDevice {
fn from(device: tch::Device) -> Self {
match device {
tch::Device::Cpu => TchDevice::Cpu,
tch::Device::Cuda(num) => TchDevice::Cuda(num),
tch::Device::Mps => TchDevice::Mps,
}
}
}

View File

@ -4,22 +4,14 @@ use std::ops::{Add, Div, Mul, Range, Sub};
impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
fn from_data<const D: usize>(data: Data<E, D>, device: TchDevice) -> TchTensor<E, D> {
let device = match device {
TchDevice::Cpu => tch::Device::Cpu,
TchDevice::Cuda(num) => tch::Device::Cuda(num),
};
TchTensor::from_data(data, device)
TchTensor::from_data(data, device.into())
}
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
device: TchDevice,
) -> TchTensor<bool, D> {
let device = match device {
TchDevice::Cpu => tch::Device::Cpu,
TchDevice::Cuda(num) => tch::Device::Cuda(num),
};
TchTensor::from_data(data, device)
TchTensor::from_data(data, device.into())
}
fn random<const D: usize>(
@ -47,7 +39,7 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
}
Distribution::Normal(mean, std) => {
let mut tensor = TchTensor::<E, D>::empty(shape, device);
tensor.tensor = tensor.tensor.normal(mean, std);
tensor.tensor = tensor.tensor.normal_(mean, std);
tensor
}
}
@ -107,13 +99,9 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
tensor: &TchTensor<bool, D>,
device: TchDevice,
) -> TchTensor<bool, D> {
let device = match device {
TchDevice::Cpu => tch::Device::Cpu,
TchDevice::Cuda(num) => tch::Device::Cuda(num),
};
TchTensor {
kind: tensor.kind,
tensor: tensor.tensor.to(device),
tensor: tensor.tensor.to(device.into()),
shape: tensor.shape,
}
}
@ -134,20 +122,13 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
}
fn device<const D: usize>(tensor: &TchTensor<E, D>) -> TchDevice {
match tensor.tensor.device() {
tch::Device::Cpu => TchDevice::Cpu,
tch::Device::Cuda(num) => TchDevice::Cuda(num),
}
tensor.tensor.device().into()
}
fn to_device<const D: usize>(tensor: &TchTensor<E, D>, device: TchDevice) -> TchTensor<E, D> {
let device = match device {
TchDevice::Cpu => tch::Device::Cpu,
TchDevice::Cuda(num) => tch::Device::Cuda(num),
};
TchTensor {
kind: tensor.kind,
tensor: tensor.tensor.to(device),
tensor: tensor.tensor.to(device.into()),
shape: tensor.shape,
}
}
@ -419,16 +400,18 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
}
fn mean_dim<const D: usize>(tensor: &TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
let tensor = tensor
.tensor
.mean_dim(&[dim as i64], true, tensor.kind.kind());
let tensor =
tensor
.tensor
.mean_dim(Some([dim as i64].as_slice()), true, tensor.kind.kind());
to_tensor(tensor)
}
fn sum_dim<const D: usize>(tensor: &TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
let tensor = tensor
.tensor
.sum_dim_intlist(&[dim as i64], true, tensor.kind.kind());
let tensor =
tensor
.tensor
.sum_dim_intlist(Some([dim as i64].as_slice()), true, tensor.kind.kind());
to_tensor(tensor)
}

View File

@ -5,7 +5,7 @@ authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
description = "Burn tensor test gen crate."
repository = "https://github.com/burn-rs/burn/tree/main/burn-tensor-testgen"
readme="README.md"
readme = "README.md"
license = "MIT/Apache-2.0"
edition = "2021"
@ -13,6 +13,6 @@ edition = "2021"
proc-macro = true
[dependencies]
syn = "1.0"
quote = "1.0"
proc-macro2 = "1.0"
syn = "1.0.107"
quote = "1.0.23"
proc-macro2 = "1.0.49"

View File

@ -8,7 +8,7 @@ This library provides multiple tensor implementations hidden behind
an easy to use API that supports reverse mode automatic differentiation.
"""
repository = "https://github.com/burn-rs/burn/tree/main/burn-tensor"
readme="README.md"
readme = "README.md"
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
categories = ["science"]
license = "MIT/Apache-2.0"
@ -21,14 +21,16 @@ experimental-named-tensor = []
[dependencies]
burn-tensor-testgen = { version = "0.3.0", path = "../burn-tensor-testgen", optional = true }
num-traits = "0.2"
derive-new = "0.5"
rand = "0.8"
statrs = "0.16"
half = { version = "1.6", features = ["num-traits"] } # needs to be 1.6 to work with tch
num-traits = "0.2.15"
derive-new = "0.5.9"
rand = "0.8.5"
statrs = "0.16.0"
half = { version = "1.6.0", features = [
"num-traits",
] } # needs to be 1.6 to work with tch
# Autodiff
nanoid = "0.4"
nanoid = "0.4.0"
# Serialization
serde = { version = "1.0", features = ["derive"] }
serde = { version = "1.0.151", features = ["derive"] }

View File

@ -4,7 +4,7 @@ version = "0.3.0"
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
description = "BURN: Burn Unstoppable Rusty Neurons"
repository = "https://github.com/burn-rs/burn"
readme="README.md"
readme = "README.md"
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
categories = ["science"]
license = "MIT/Apache-2.0"
@ -20,30 +20,32 @@ burn-autodiff = { version = "0.3.0", path = "../burn-autodiff" }
burn-dataset = { version = "0.3.0", path = "../burn-dataset", default-features = false }
burn-derive = { version = "0.3.0", path = "../burn-derive" }
thiserror = "1.0"
num-traits = "0.2"
derive-new = "0.5"
rand = "0.8"
thiserror = "1.0.38"
num-traits = "0.2.15"
derive-new = "0.5.9"
rand = "0.8.5"
# Metrics
nvml-wrapper = "0.8"
textplots = "0.8"
rgb = "0.8"
terminal_size = "0.2"
nvml-wrapper = "0.8.0"
textplots = "0.8.0"
rgb = "0.8.34"
terminal_size = "0.2.3"
# Console
indicatif = "0.17"
log4rs = "1.2"
log = "0.4"
indicatif = "0.17.2"
log4rs = "1.2.0"
log = "0.4.17"
# Serialize Deserialize
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
flate2 = "1.0"
serde = { version = "1.0.151", features = ["derive"] }
serde_json = "1.0.91"
flate2 = "1.0.25"
# Parameter & Optimization
nanoid = "0.4"
nanoid = "0.4.0"
[dev-dependencies]
burn-dataset = { version = "0.3.0", path = "../burn-dataset", features = ["fake"] }
burn-dataset = { version = "0.3.0", path = "../burn-dataset", features = [
"fake",
] }
burn-ndarray = { version = "0.3.0", path = "../burn-ndarray" }

View File

@ -21,4 +21,4 @@ burn-tch = { path = "../../burn-tch", optional = true }
burn-ndarray = { path = "../../burn-ndarray", optional = true }
# Serialization
serde = { version = "1.0", features = ["derive"] }
serde = { version = "1.0.151", features = ["derive"] }

View File

@ -21,7 +21,10 @@ mod tch_gpu {
use mnist::training;
pub fn run() {
#[cfg(not(target_os = "macos"))]
let device = TchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = TchDevice::Mps;
training::run::<ADBackendDecorator<TchBackend<burn::tensor::f16>>>(device);
}
}

View File

@ -12,4 +12,4 @@ burn-autodiff = { path = "../../burn-autodiff" }
burn-ndarray = { path = "../../burn-ndarray" }
# Serialization
serde = { version = "1.0", features = ["derive"] }
serde = { version = "1.0.151", features = ["derive"] }

View File

@ -16,8 +16,11 @@ burn-autodiff = { path = "../../burn-autodiff" }
burn-tch = { path = "../../burn-tch" }
# Tokenizer
tokenizers = { version = "0.13", default-features = false, features = ["onig", "http"] }
tokenizers = { version = "0.13.2", default-features = false, features = [
"onig",
"http",
] }
# Utils
derive-new = "0.5"
serde = { version = "1.0", features = ["derive"] }
derive-new = "0.5.9"
serde = { version = "1.0.151", features = ["derive"] }

View File

@ -13,7 +13,11 @@ fn main() {
);
text_classification::training::train::<Backend, AgNewsDataset>(
burn_tch::TchDevice::Cuda(0),
if cfg!(target_os = "macos") {
burn_tch::TchDevice::Mps
} else {
burn_tch::TchDevice::Cuda(0)
},
AgNewsDataset::train(),
AgNewsDataset::test(),
config,

View File

@ -13,7 +13,11 @@ fn main() {
);
text_classification::training::train::<Backend, DbPediaDataset>(
burn_tch::TchDevice::Cuda(0),
if cfg!(target_os = "macos") {
burn_tch::TchDevice::Mps
} else {
burn_tch::TchDevice::Cuda(0)
},
DbPediaDataset::train(),
DbPediaDataset::test(),
config,

View File

@ -14,11 +14,14 @@ default = []
burn = { path = "../../burn" }
burn-autodiff = { path = "../../burn-autodiff" }
burn-tch = { path = "../../burn-tch" }
log = "0.4"
log = "0.4.17"
# Tokenizer
tokenizers = { version = "0.13", default-features = false, features = ["onig", "http"] }
tokenizers = { version = "0.13.2", default-features = false, features = [
"onig",
"http",
] }
# Utils
derive-new = "0.5"
serde = { version = "1.0", features = ["derive"] }
derive-new = "0.5.9"
serde = { version = "1.0.151", features = ["derive"] }

View File

@ -13,7 +13,11 @@ fn main() {
);
text_generation::training::train::<Backend, DbPediaDataset>(
burn_tch::TchDevice::Cuda(0),
if cfg!(target_os = "macos") {
burn_tch::TchDevice::Mps
} else {
burn_tch::TchDevice::Cuda(0)
},
DbPediaDataset::train(),
DbPediaDataset::test(),
config,