Fix/devices api (#990)

This commit is contained in:
Nathaniel Simard 2023-11-22 10:24:24 -05:00 committed by GitHub
parent 3d6c738776
commit 630044e96b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 42 additions and 25 deletions

View File

@@ -84,9 +84,14 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// Type to save and load the module.
type Record: Record;
/// Collects devices in the given vector and returns it with the devices found in the module
/// structure without duplicates.
fn devices(&self, devices: Devices<B>) -> Devices<B>;
/// Return all the devices found in the underneath module tree added to the given vector
/// without duplicates.
fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;
/// Return all the devices found in the underneath module tree without duplicates.
fn devices(&self) -> Devices<B> {
self.collect_devices(Devices::<B>::new())
}
/// Fork the module and all of its sub-modules to the given device.
///

View File

@@ -75,7 +75,7 @@ macro_rules! constant {
self
}
fn devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
fn collect_devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
devices
}
};
@@ -147,7 +147,7 @@ impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
self.to_device(device)
}
fn devices(&self, mut devices: Devices<B>) -> Devices<B> {
fn collect_devices(&self, mut devices: Devices<B>) -> Devices<B> {
let device = self.device();
if !devices.contains(&device) {
@@ -195,7 +195,7 @@ impl<B: Backend> Module<B> for PhantomData<B> {
self
}
fn devices(&self, devices: Devices<B>) -> Devices<B> {
fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
devices
}
}

View File

@@ -37,9 +37,9 @@ where
self.map(|module| module.fork(device))
}
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
if let Some(module) = self.as_ref() {
devices = module.devices(devices);
devices = module.collect_devices(devices);
}
devices
@@ -105,9 +105,9 @@ where
self.into_iter().map(|module| module.fork(device)).collect()
}
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
for module in self.iter() {
devices = module.devices(devices);
devices = module.collect_devices(devices);
}
devices
@@ -134,9 +134,9 @@ where
{
type Record = [T::Record; N];
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
for module in self.iter() {
devices = module.devices(devices);
devices = module.collect_devices(devices);
}
devices

View File

@@ -95,7 +95,10 @@ impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
self.to_device(device) // Same thing here since no grad.
}
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
fn collect_devices(
&self,
mut devices: Vec<<B as Backend>::Device>,
) -> Vec<<B as Backend>::Device> {
let device = self.value.read().unwrap().device();
if !devices.contains(&device) {

View File

@@ -75,7 +75,10 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
})
}
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
fn collect_devices(
&self,
mut devices: Vec<<B as Backend>::Device>,
) -> Vec<<B as Backend>::Device> {
let device = self.device();
if !devices.contains(&device) {
@@ -122,7 +125,10 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
self.to_device(device) // Don't support autodiff.
}
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
fn collect_devices(
&self,
mut devices: Vec<<B as Backend>::Device>,
) -> Vec<<B as Backend>::Device> {
let device = self.device();
if !devices.contains(&device) {
@@ -169,7 +175,10 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
self.to_device(device) // Don't support autodiff.
}
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
fn collect_devices(
&self,
mut devices: Vec<<B as Backend>::Device>,
) -> Vec<<B as Backend>::Device> {
let device = self.device();
if !devices.contains(&device) {

View File

@@ -29,7 +29,7 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
let num_params_fn = generator.gen_num_params();
let visit = generator.gen_visit();
let map_mut = generator.gen_map();
let devices = generator.gen_devices();
let collect_devices = generator.gen_collect_devices();
let to_device = generator.gen_to_device();
let fork = generator.gen_fork();
let valid_fn = generator.gen_valid();
@@ -54,7 +54,7 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
#visit
#map_mut
#devices
#collect_devices
#to_device
#fork
}

View File

@@ -4,7 +4,7 @@ use proc_macro2::TokenStream;
pub(crate) trait ModuleCodegen {
fn gen_num_params(&self) -> TokenStream;
fn gen_visit(&self) -> TokenStream;
fn gen_devices(&self) -> TokenStream;
fn gen_collect_devices(&self) -> TokenStream;
fn gen_to_device(&self) -> TokenStream;
fn gen_fork(&self) -> TokenStream;
fn gen_map(&self) -> TokenStream;

View File

@@ -39,15 +39,15 @@ impl ModuleCodegen for StructModuleCodegen {
}
}
fn gen_devices(&self) -> TokenStream {
fn gen_collect_devices(&self) -> TokenStream {
let body = self.gen_fields_fn(|name| {
quote! {
let devices = burn::module::Module::<B>::devices(&self.#name, devices);
let devices = burn::module::Module::<B>::collect_devices(&self.#name, devices);
}
});
quote! {
fn devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
fn collect_devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
#body
devices

View File

@@ -88,7 +88,7 @@ impl<B: Backend> TextClassificationModel<B> {
pub fn forward(&self, item: TextClassificationTrainingBatch<B>) -> ClassificationOutput<B> {
// Get batch and sequence length, and the device
let [batch_size, seq_length] = item.tokens.dims();
let device = &self.embedding_token.devices(Vec::new())[0];
let device = &self.embedding_token.devices()[0];
// Move tensors to the correct device
let tokens = item.tokens.to_device(device);
@@ -128,7 +128,7 @@ impl<B: Backend> TextClassificationModel<B> {
pub fn infer(&self, item: TextClassificationInferenceBatch<B>) -> Tensor<B, 2> {
// Get batch and sequence length, and the device
let [batch_size, seq_length] = item.tokens.dims();
let device = &self.embedding_token.devices(Vec::new())[0];
let device = &self.embedding_token.devices()[0];
// Move tensors to the correct device
let tokens = item.tokens.to_device(device);

View File

@@ -58,7 +58,7 @@ impl<B: Backend> TextGenerationModel<B> {
item: TrainingTextGenerationBatch<B>,
) -> ClassificationOutput<B> {
let [batch_size, seq_length] = item.tokens_inputs.dims();
let device = &self.devices(Vec::new())[0];
let device = &self.devices()[0];
let inputs = item.tokens_inputs.to_device(device);
let targets = item.targets.to_device(device);