commit 1d3bbaab13 (parent 441a7011ce)
Author: Caio Piccirillo, 2023-08-08 23:57:51 +02:00 (committed via GitHub)
45 changed files with 105 additions and 84 deletions

.github/workflows/typos.yml (new file)

@@ -0,0 +1,13 @@
name: Typos
on: pull_request
jobs:
run:
name: Spell check with Typos
runs-on: ubuntu-20.04
steps:
- name: Checkout Actions Repository
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab
- name: Check spelling
uses: crate-ci/typos@8a7996b4bcfa526668e5a9e7914330428897e205

@@ -16,7 +16,7 @@ __Sections__
Modules are a way of creating neural network structures that can be easily optimized, saved, and loaded with little to no boilerplate.
Unlike other frameworks, a module does not force the declaration of the forward pass, leaving it up to the implementer to decide how it should be defined.
Additionally, most modules are created using a (de)serializable configuration, which defines the structure of the module and its hyper-parameters.
-Parameters and hyper-parameters are not serialized into the same file and both are normaly necessary to load a module for inference.
+Parameters and hyper-parameters are not serialized into the same file and both are normally necessary to load a module for inference.
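An aside on the paragraph above: a minimal standalone sketch of the config/parameter split it describes, using plain serde types as a stand-in for Burn's actual `Config` and `Module` derives (all names here are illustrative):
```rust
use serde::{Deserialize, Serialize};

// Hyper-parameters: a small (de)serializable config that defines the module structure.
#[derive(Serialize, Deserialize)]
struct LinearConfig {
    d_input: usize,
    d_output: usize,
}

// Learned parameters: recorded and loaded separately from the config.
struct Linear {
    weight: Vec<f32>,
}

impl LinearConfig {
    // Rebuild the module structure from the hyper-parameters alone.
    fn init(&self) -> Linear {
        Linear {
            weight: vec![0.0; self.d_input * self.d_output],
        }
    }
}

fn main() {
    // The config round-trips through JSON; the weights would come from a
    // separate record file when loading the module for inference.
    let config: LinearConfig =
        serde_json::from_str(r#"{"d_input": 4, "d_output": 2}"#).unwrap();
    let module = config.init();
    println!("parameter count: {}", module.weight.len()); // 8
}
```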
### Optimization
@@ -71,7 +71,7 @@ When performing an optimization step, the adaptor handles the following:
3. Makes sure that the gradient, the tensor, and the optimizer state associated with the current parameter are on the same device.
The device can be different if the state is loaded from disk to restart training.
4. Performs the simple optimizer step using the inner tensor since the operations done by the optimizer should not be tracked in the autodiff graph.
-5. Updates the state for the current parameter and returns the updated tensor, making sure it's properly registered into the autodiff graph if gradients are maked as required.
+5. Updates the state for the current parameter and returns the updated tensor, making sure it's properly registered into the autodiff graph if gradients are marked as required.
Note that a parameter can still be updated by another process, as is the case with running metrics used in batch norm.
These tensors are still wrapped using the `Param` struct so that they are included in the module's state and given a proper parameter ID, but they are not registered in the autodiff graph.
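As an illustration of steps 3 to 5, here is a standalone sketch of an optimizer step that keys its state by parameter ID and updates plain, untracked data (illustrative types, not Burn's actual adaptor):
```rust
use std::collections::HashMap;

// A parameter with a stable ID, standing in for Burn's `Param` wrapper.
struct Param {
    id: u64,
    value: Vec<f32>,
}

struct SgdState {
    momentum: Vec<f32>,
}

struct Sgd {
    lr: f32,
    beta: f32,
    // Optimizer state looked up per parameter ID, as in step 5 above.
    state: HashMap<u64, SgdState>,
}

impl Sgd {
    fn step(&mut self, param: &mut Param, grad: &[f32]) {
        let state = self.state.entry(param.id).or_insert_with(|| SgdState {
            momentum: vec![0.0; grad.len()],
        });
        for ((w, g), m) in param.value.iter_mut().zip(grad).zip(&mut state.momentum) {
            *m = self.beta * *m + *g; // update the optimizer state
            *w -= self.lr * *m; // plain update, never tracked by autodiff
        }
    }
}

fn main() {
    let mut sgd = Sgd { lr: 0.1, beta: 0.9, state: HashMap::new() };
    let mut param = Param { id: 0, value: vec![1.0, 2.0] };
    sgd.step(&mut param, &[0.5, -0.5]);
    println!("{:?}", param.value); // [0.95, 2.05]
}
```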

@@ -295,7 +295,7 @@ Compile `scripts/publish.rs` using this command:
rustc scripts/publish.rs --crate-type bin --out-dir scripts
```
-## Disclamer
+## Disclaimer
Burn is currently in active development, and there will be breaking changes. While any resulting
issues are likely to be easy to fix, there are no guarantees at this stage.

_typos.toml (new file)

@@ -0,0 +1,8 @@
[default]
extend-ignore-identifiers-re = [
"NdArray*",
"ND"
]
[files]
extend-exclude = ["assets/ModuleSerialization.xml"]

@@ -38,7 +38,7 @@ impl Graph {
/// be shared with other graphs, therefore they are going to be cleared.
///
/// This is usefull, since the graph is supposed to be consumed only once for backprop, and
-/// keeping all the tensors alive for multiple backward call is a heavy waste of ressources.
+/// keeping all the tensors alive for multiple backward call is a heavy waste of resources.
pub fn steps(self) -> NodeSteps {
let mut map_drain = HashMap::new();
self.execute_mut(|map| {

@@ -11,7 +11,7 @@ use std::marker::PhantomData;
/// Operation in preparation.
///
-/// There are 3 diffent modes: 'Init', 'Tracked' and 'UnTracked'.
+/// There are 3 different modes: 'Init', 'Tracked' and 'UnTracked'.
/// Each mode has its own set of functions to minimize cloning for unused backward states.
#[derive(new)]
pub struct OpsPrep<Backward, B, S, const D: usize, const N: usize, Mode = Init> {

@@ -4,7 +4,7 @@ mod tests {
use burn_tensor::{Data, Distribution, Int, Shape, Tensor};
#[test]
-fn should_handle_broacast_during_backward() {
+fn should_handle_broadcast_during_backward() {
let x: Tensor<TestADBackend, 2> = Tensor::from_data(
Tensor::<TestADBackend, 1, Int>::arange(0..6)
.into_data()

@@ -137,7 +137,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
/// # Note
///
/// Don't use this function after an update on the same thread where other threads might have to
-/// register their update before the actual synchonization needs to happen.
+/// register their update before the actual synchronization needs to happen.
pub fn value_sync(&self) -> Tensor<B, D> {
let thread_id = get_thread_current_id();
let mut map = self.values.lock().unwrap();

@@ -35,7 +35,7 @@ pub struct GeneratePaddingMask<B: Backend> {
pub fn generate_padding_mask<B: Backend>(
pad_token: usize,
tokens_list: Vec<Vec<usize>>,
-max_seq_lenght: Option<usize>,
+max_seq_length: Option<usize>,
device: &B::Device,
) -> GeneratePaddingMask<B> {
let mut max_size = 0;
@@ -46,9 +46,9 @@ pub fn generate_padding_mask<B: Backend>(
max_size = tokens.len();
}
-if let Some(max_seq_lenght) = max_seq_lenght {
-if tokens.len() >= max_seq_lenght {
-max_size = max_seq_lenght;
+if let Some(max_seq_length) = max_seq_length {
+if tokens.len() >= max_seq_length {
+max_size = max_seq_length;
break;
}
}
@@ -61,9 +61,9 @@ pub fn generate_padding_mask<B: Backend>(
let mut seq_length = tokens.len();
let mut tokens = tokens;
-if let Some(max_seq_lenght) = max_seq_lenght {
-if seq_length > max_seq_lenght {
-seq_length = max_seq_lenght;
+if let Some(max_seq_length) = max_seq_length {
+if seq_length > max_seq_length {
+seq_length = max_seq_length;
let _ = tokens.split_off(seq_length);
}
}
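The hunks above only rename `max_seq_lenght` to `max_seq_length`, but the surrounding logic is worth spelling out: the batch is padded to its longest sequence, capped at `max_seq_length`. A standalone sketch of that behavior with plain `Vec`s (not the actual implementation, which also builds the boolean mask tensor):
```rust
fn pad_sequences(
    pad_token: usize,
    tokens_list: Vec<Vec<usize>>,
    max_seq_length: Option<usize>,
) -> Vec<Vec<usize>> {
    // Batch width: the longest sequence, capped at `max_seq_length`.
    let mut max_size = tokens_list.iter().map(Vec::len).max().unwrap_or(0);
    if let Some(max_seq_length) = max_seq_length {
        max_size = max_size.min(max_seq_length);
    }
    tokens_list
        .into_iter()
        .map(|mut tokens| {
            tokens.truncate(max_size); // truncate sequences that are too long
            tokens.resize(max_size, pad_token); // pad the short ones
            tokens
        })
        .collect()
}

fn main() {
    let padded = pad_sequences(0, vec![vec![5, 6, 7, 8], vec![9]], Some(3));
    assert_eq!(padded, vec![vec![5, 6, 7], vec![9, 0, 0]]);
}
```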

@@ -392,7 +392,7 @@ mod tests {
let output_1 = mha.forward(input_1);
let output_2 = mha.forward(input_2);
-// Check that the begginning of each tensor is the same
+// Check that the beginning of each tensor is the same
output_1
.context
.slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])

@@ -150,7 +150,7 @@ impl<B: Backend> Lstm<B> {
let biased_ig_input_sum = self.gate_product(&input_t, &hidden_state, &self.input_gate);
let add_values = activation::sigmoid(biased_ig_input_sum);
-// o(utput)g(ate) tensors
+// o(output)g(ate) tensors
let biased_og_input_sum = self.gate_product(&input_t, &hidden_state, &self.output_gate);
let output_values = activation::sigmoid(biased_og_input_sum);

@@ -34,7 +34,7 @@ pub struct JsonGzFileRecorder<S: PrecisionSettings> {
_settings: PhantomData<S>,
}
-/// File recorder using [pretty json format](serde_json) for easy redability.
+/// File recorder using [pretty json format](serde_json) for easy readability.
#[derive(new, Debug, Default, Clone)]
pub struct PrettyJsonFileRecorder<S: PrecisionSettings> {
_settings: PhantomData<S>,

@@ -7,7 +7,7 @@ use serde::{de::DeserializeOwned, Serialize};
///
/// # Notes
///
-/// This is especialy useful in no_std environment where weights are stored directly in
+/// This is especially useful in no_std environment where weights are stored directly in
/// compiled binaries.
pub trait BytesRecorder:
Recorder<RecordArgs = (), RecordOutput = Vec<u8>, LoadArgs = Vec<u8>>

@@ -115,7 +115,7 @@ impl<B: Backend, const D: usize> Record for Param<Tensor<B, D>> {
}
}
-// Type that can be serialized as is without any convertion.
+// Type that can be serialized as is without any conversion.
macro_rules! primitive {
($type:ty) => {
impl Record for $type {

@@ -27,7 +27,7 @@ pub enum TestEnumConfig {
#[cfg(feature = "std")]
#[test]
fn struct_config_should_impl_serde() {
let config = TestStructConfig::new(2, 3.0, "Allo".to_string(), TestEmptyStructConfig::new());
let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new());
let file_path = "/tmp/test_struct_config.json";
config.save(file_path).unwrap();
@@ -38,13 +38,13 @@ fn struct_config_should_impl_serde() {
#[test]
fn struct_config_should_impl_clone() {
let config = TestStructConfig::new(2, 3.0, "Allo".to_string(), TestEmptyStructConfig::new());
let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new());
assert_eq!(config, config.clone());
}
#[test]
fn struct_config_should_impl_display() {
let config = TestStructConfig::new(2, 3.0, "Allo".to_string(), TestEmptyStructConfig::new());
let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new());
assert_eq!(burn::config::config_to_json(&config), config.to_string());
}
@@ -75,7 +75,7 @@ fn enum_config_one_value_should_impl_serde() {
#[cfg(feature = "std")]
#[test]
fn enum_config_multiple_values_should_impl_serde() {
let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
let config = TestEnumConfig::Multiple(42.0, "Allow".to_string());
let file_path = "/tmp/test_enum_multiple_values_config.json";
config.save(file_path).unwrap();
@@ -86,19 +86,19 @@ fn enum_config_multiple_values_should_impl_serde() {
#[test]
fn enum_config_should_impl_clone() {
let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
let config = TestEnumConfig::Multiple(42.0, "Allow".to_string());
assert_eq!(config, config.clone());
}
#[test]
fn enum_config_should_impl_display() {
let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
let config = TestEnumConfig::Multiple(42.0, "Allow".to_string());
assert_eq!(burn::config::config_to_json(&config), config.to_string());
}
#[test]
fn struct_config_can_load_binary() {
let config = TestStructConfig::new(2, 3.0, "Allo".to_string(), TestEmptyStructConfig::new());
let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new());
let binary = config_to_json(&config).as_bytes().to_vec();

@@ -8,7 +8,7 @@ fn speech_command() {
let item = test.get(index).unwrap();
println!("Item: {:?}", item);
println!("Item Lengh: {:?}", item.audio_samples.len());
println!("Item Length: {:?}", item.audio_samples.len());
println!("Label: {}", item.label.to_string());
assert_eq!(test.len(), 4890);

@@ -576,7 +576,7 @@ where
///
/// * A `Result` which is `Ok` if the table could be created, `Err` otherwise.
///
-/// TODO (@antimora): add support creating a table with columns coresponding to the item fields
+/// TODO (@antimora): add support creating a table with columns corresponding to the item fields
fn create_table(&self, split: &str) -> Result<()> {
// Check if the split already exists
if self.splits.read().unwrap().contains(split) {

@@ -50,7 +50,7 @@ def download_and_export(name: str, subset: str, db_file: str, token: str, cache_
dataset = dataset.flatten()
# Rename columns to remove dots from the names
-dataset = rename_colums(dataset)
+dataset = rename_columns(dataset)
print(f"Saving dataset: {name} - {key}")
print(f"Dataset features: {dataset.features}")
@@ -81,7 +81,7 @@ def disable_decoding(dataset):
return dataset
-def rename_colums(dataset):
+def rename_columns(dataset):
"""
Rename columns to remove dots from the names. Dots appear in the column names because of the flattening.
Dots are not allowed in column names in rust and sql (unless quoted). So we replace them with underscores.

@@ -85,7 +85,7 @@ pub(crate) fn parse_fields(ast: &syn::DeriveInput) -> Vec<Field> {
}
}
syn::Data::Enum(_) => panic!("Only struct can be derived"),
syn::Data::Union(_) => panic!("Only struct cna be derived"),
syn::Data::Union(_) => panic!("Only struct can be derived"),
};
fields
}

@@ -18,6 +18,6 @@ pub mod onnx;
/// The module for generating the burn code.
pub mod burn;
-mod formater;
+mod formatter;
mod logger;
-pub use formater::*;
+pub use formatter::*;

@@ -204,7 +204,7 @@ message NodeProto {
repeated string output = 2; // namespace Value
// An optional identifier for this node in a graph.
-// This field MAY be absent in ths version of the IR.
+// This field MAY be absent in this version of the IR.
string name = 3; // namespace Node
// The symbolic identifier of the Operator to execute.
@@ -403,7 +403,7 @@ message ModelProto {
//
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
-// or standard opserator sets are given higher priotity or this is treated as error) is defined by
+// or standard opserator sets are given higher priority or this is treated as error) is defined by
// the runtimes.
//
// The operator sets imported by FunctionProto should be compatible with the ones

@@ -1,4 +1,4 @@
-// Orginally copied from the burn/examples/mnist package
+// Originally copied from the burn/examples/mnist package
use burn::{
config::Config,

@@ -1,4 +1,4 @@
-// Orginally copied from the burn/examples/mnist package
+// Originally copied from the burn/examples/mnist package
use alloc::vec::Vec;

@@ -1,4 +1,4 @@
-// Orginally copied from the burn/examples/mnist package
+// Originally copied from the burn/examples/mnist package
use crate::{
conv::{ConvBlock, ConvBlockConfig},
@@ -52,11 +52,11 @@ impl<B: Backend> Model<B> {
}
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
-let [batch_size, heigth, width] = input.dims();
+let [batch_size, height, width] = input.dims();
-let x = input.reshape([batch_size, 1, heigth, width]).detach();
+let x = input.reshape([batch_size, 1, height, width]).detach();
let x = self.conv.forward(x);
-let x = x.reshape([batch_size, heigth * width]);
+let x = x.reshape([batch_size, height * width]);
let x = self.input.forward(x);
let x = self.mlp.forward(x);

@@ -32,7 +32,7 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
/// Create a tensor that was created from an operation executed on a parent tensor.
///
-/// If the child tensor shared the same storage as its parent, it will be cloned, effectivly
+/// If the child tensor shared the same storage as its parent, it will be cloned, effectively
/// tracking how much tensors point to the same memory space.
pub fn from_existing(tensor: tch::Tensor, storage_parent: StorageRef) -> Self {
let storage_child = tensor.data_ptr();

@@ -19,7 +19,7 @@ This library provides multiple tensor implementations hidden behind an easy to use
### Backends
-For now, only two backends are implementated, but adding new ones should not be that hard.
+For now, only two backends are implemented, but adding new ones should not be that hard.
* [X] Pytorch using [tch-rs](https://github.com/LaurentMazare/tch-rs)
* [X] 100% Rust backend using [ndarray](https://github.com/rust-ndarray/ndarray)
@@ -33,7 +33,7 @@ For now, only two backends are implementated, but adding new ones should not be
Automatic differentiation is implemented as just another tensor backend without any global state.
It's possible since we keep track of the order in which each operation as been executed and the tape is only created when calculating the gradients.
To do so, each operation creates a new node which has a reference to its parent nodes.
-Therefore, creating the tape only requires a simple and efficent graph traversal algorithm.
+Therefore, creating the tape only requires a simple and efficient graph traversal algorithm.
```rust
let x = ADTensor::from_tensor(x_ndarray);
@@ -62,5 +62,5 @@ This crate can be used without the standard library (`#![no_std]`) with `alloc`
the default `std` feature.
* `std` - enables the standard library.
-* `burn-tensor-testgen` - enables test macros for genarating tensor tests.
+* `burn-tensor-testgen` - enables test macros for generating tensor tests.

@@ -278,7 +278,7 @@ impl TensorCheck {
.details(
format!(
"The ranges array must be smaller or equal to the tensor number of dimensions. \
-Tensor number of dimensions: {n_dims_tensor}, ranges array lenght {n_dims_ranges}."
+Tensor number of dimensions: {n_dims_tensor}, ranges array length {n_dims_ranges}."
)));
}
@@ -334,7 +334,7 @@ impl TensorCheck {
.details(
format!(
"The ranges array must be smaller or equal to the tensor number of dimensions. \
-Tensor number of dimensions: {D1}, ranges array lenght {D2}."
+Tensor number of dimensions: {D1}, ranges array length {D2}."
)));
}
@@ -510,7 +510,7 @@
}
/// The goal is to minimize the cost of checks when there are no error, but it's way less
-/// important when an error occured, crafting a comprehensive error message is more important
+/// important when an error occurred, crafting a comprehensive error message is more important
/// than optimizing string manipulation.
fn register(self, ops: &str, error: TensorError) -> Self {
let errors = match self {
@@ -634,7 +634,7 @@
}
/// We use a macro for all checks, since the panic message file and line number will match the
-/// function that does the check instead of a the generic error.rs crate private unreleated file
+/// function that does the check instead of a the generic error.rs crate private unrelated file
/// and line number.
#[macro_export(local_inner_macros)]
macro_rules! check {

@@ -250,7 +250,7 @@
/// Detach the current tensor from the autodiff graph.
/// This function does nothing when autodiff is not enabled.
-/// This can be used in batchers or elsewere to ensure that previous operations are not
+/// This can be used in batchers or elsewhere to ensure that previous operations are not
/// considered in the autodiff graph.
pub fn detach(self) -> Self {
Self::new(B::detach(self.primitive))

@@ -35,7 +35,7 @@ where
Self::new(K::add_scalar(self.primitive, other))
}
-/// Applies element wise substraction operation.
+/// Applies element wise subtraction operation.
///
/// `y = x2 - x1`
#[allow(clippy::should_implement_trait)]
@@ -44,7 +44,7 @@ where
Self::new(K::sub(self.primitive, other.primitive))
}
-/// Applies element wise substraction operation with a scalar.
+/// Applies element wise subtraction operation with a scalar.
///
/// `y = x - s`
pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {
@@ -238,7 +238,7 @@
///
/// # Notes
///
-/// The index tensor shoud have the same shape as the original tensor except for the dim
+/// The index tensor should have the same shape as the original tensor except for the dim
/// specified.
pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
check!(TensorCheck::gather::<D>(
@@ -250,7 +250,7 @@
Self::new(K::gather(dim, self.primitive, indices))
}
-/// Assign the gathered elements corresponding to the given indices along the speficied dimension
+/// Assign the gathered elements corresponding to the given indices along the specified dimension
/// from the value tensor to the original tensor using sum reduction.
///
/// Example using a 3D tensor:
@@ -261,7 +261,7 @@
///
/// # Notes
///
-/// The index tensor shoud have the same shape as the original tensor except for the speficied
+/// The index tensor should have the same shape as the original tensor except for the specified
/// dimension. The value and index tensors should have the same shape.
///
/// Other references to the input tensor will not be modified by this operation.
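To make the shape rule in these doc comments concrete, here is a standalone sketch of `gather` along dimension 1 for a 2-D tensor, using nested `Vec`s instead of Burn's tensor API:
```rust
// The output takes the shape of `indices`, which matches `input` in every
// dimension except the gathered one.
fn gather_dim1(input: &[Vec<f32>], indices: &[Vec<usize>]) -> Vec<Vec<f32>> {
    input
        .iter()
        .zip(indices)
        .map(|(row, idx)| idx.iter().map(|&j| row[j]).collect())
        .collect()
}

fn main() {
    let input = vec![vec![10.0, 20.0, 30.0], vec![40.0, 50.0, 60.0]];
    let indices = vec![vec![2, 0], vec![1, 1]];
    let gathered = gather_dim1(&input, &indices);
    assert_eq!(gathered, vec![vec![30.0, 10.0], vec![50.0, 50.0]]);
}
```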

@@ -326,7 +326,7 @@ impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E
let tolerance = libm::pow(0.1, precision as f64);
if err > tolerance {
-// Only print the first 5 differents values.
+// Only print the first 5 different values.
if num_diff < max_num_diff {
message += format!(
"\n => Position {i}: {a} != {b} | difference {err} > tolerance {tolerance}"

@@ -195,8 +195,8 @@ where
LR::Record: 'static,
{
self.init_logger();
-let callack = Box::new(self.dashboard);
-let callback = Box::new(AsyncTrainerCallback::new(callack));
+let callback = Box::new(self.dashboard);
+let callback = Box::new(AsyncTrainerCallback::new(callback));
let checkpointer_optimizer = match self.checkpointer_optimizer {
Some(checkpointer) => {

@@ -54,7 +54,7 @@ pub trait Metric: Send + Sync {
/// Adaptor are used to transform types so that they can be used by metrics.
///
/// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are
-/// registed with the [leaner buidler](crate::learner::LearnerBuilder) .
+/// registered with the [leaner buidler](crate::learner::LearnerBuilder) .
pub trait Adaptor<T> {
/// Adapt the type to be passed to a [metric](Metric).
fn adapt(&self) -> T;
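A standalone sketch of the adaptor pattern this trait enables: a model output type implements `Adaptor` once per metric input it can feed (the output and input types here are illustrative; only the `adapt` signature comes from the trait above):
```rust
trait Adaptor<T> {
    fn adapt(&self) -> T;
}

// What a loss metric consumes (illustrative).
struct LossInput {
    loss: f64,
}

// A model's output type (illustrative).
struct ClassificationOutput {
    loss: f64,
    logits: Vec<f32>,
}

impl Adaptor<LossInput> for ClassificationOutput {
    // The loss metric only needs the loss value, so the adaptor extracts it.
    fn adapt(&self) -> LossInput {
        LossInput { loss: self.loss }
    }
}

fn main() {
    let output = ClassificationOutput { loss: 0.42, logits: vec![0.1, 0.9] };
    let input: LossInput = output.adapt();
    println!("loss fed to the metric: {} ({} logits)", input.loss, output.logits.len());
}
```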

@@ -21,7 +21,7 @@ pub struct TrainingProgress {
}
impl TrainingProgress {
-/// Creates a new empy training progress.
+/// Creates a new empty training progress.
pub fn none() -> Self {
Self {
progress: Progress {

@@ -83,9 +83,9 @@ impl Context {
/// # Notes
///
/// This function isn't safe, buffer can be mutated by the GPU. The users must ensure that a
-/// buffer can be mutated when lauching a compute shaders with write access to a buffer.
+/// buffer can be mutated when launching a compute shaders with write access to a buffer.
///
-/// Buffer positions are used as bindings when lauching a compute kernel.
+/// Buffer positions are used as bindings when launching a compute kernel.
pub fn execute(
&self,
work_group: WorkGroup,

@@ -22,7 +22,7 @@ pub trait ContextServer {
fn start(device: Arc<wgpu::Device>, queue: wgpu::Queue) -> Self::Client;
}
-/// Context server where each operation is added in a synchonous maner.
+/// Context server where each operation is added in a synchronous maner.
#[derive(Debug)]
pub struct SyncContextServer {
device: Arc<wgpu::Device>,
@@ -141,7 +141,7 @@ impl SyncContextServer {
fn submit(&mut self) {
assert!(
self.tasks.is_empty(),
"Tasks should be completed before submiting the current encoder."
"Tasks should be completed before submitting the current encoder."
);
let mut new_encoder = self
.device

@@ -10,7 +10,7 @@ use crate::{
use super::base::empty_from_context;
-// Output of the pad_round function. Allows to know explicitly if early return occured
+// Output of the pad_round function. Allows to know explicitly if early return occurred
pub(super) enum PaddingOutput<E: WgpuElement, const D: usize> {
Padded(WgpuTensor<E, D>),
Unchanged(WgpuTensor<E, D>),

@@ -3,7 +3,7 @@ use burn_tensor::Shape;
use std::sync::Arc;
use wgpu::Buffer;
-/// Build basic info to lauch pool 2d kernels.
+/// Build basic info to launch pool 2d kernels.
pub fn build_output_and_info_pool2d<E: WgpuElement>(
x: &WgpuTensor<E, 4>,
kernel_size: [usize; 2],

@@ -101,7 +101,7 @@
isDrawingMode: true,
});
const backgroundColor = "rgba(255, 255, 255, 255)"; // White with solid alha
const backgroundColor = "rgba(255, 255, 255, 255)"; // White with solid alpha
fabricCanvas.freeDrawingBrush.width = 25;
fabricCanvas.backgroundColor = backgroundColor;

@@ -1,5 +1,5 @@
-# Openning index.html file directly by a browser does not work because of
+# Opening index.html file directly by a browser does not work because of
# the security restrictions by the browser. Viewing the HTML file will fail with
# this error message:

@@ -1,6 +1,6 @@
#![allow(clippy::new_without_default)]
-// Orginally copied from the burn/examples/mnist package
+// Originally copied from the burn/examples/mnist package
use burn::{
module::Module,
@@ -48,15 +48,15 @@ impl<B: Backend> Model<B> {
}
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
-let [batch_size, heigth, width] = input.dims();
+let [batch_size, height, width] = input.dims();
-let x = input.reshape([batch_size, 1, heigth, width]).detach();
+let x = input.reshape([batch_size, 1, height, width]).detach();
let x = self.conv1.forward(x);
let x = self.conv2.forward(x);
let x = self.conv3.forward(x);
-let [batch_size, channels, heigth, width] = x.dims();
-let x = x.reshape([batch_size, channels * heigth * width]);
+let [batch_size, channels, height, width] = x.dims();
+let x = x.reshape([batch_size, channels * height * width]);
let x = self.dropout.forward(x);
let x = self.fc1.forward(x);

@@ -50,15 +50,15 @@ impl<B: Backend> Model<B> {
}
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
-let [batch_size, heigth, width] = input.dims();
+let [batch_size, height, width] = input.dims();
-let x = input.reshape([batch_size, 1, heigth, width]).detach();
+let x = input.reshape([batch_size, 1, height, width]).detach();
let x = self.conv1.forward(x);
let x = self.conv2.forward(x);
let x = self.conv3.forward(x);
-let [batch_size, channels, heigth, width] = x.dims();
-let x = x.reshape([batch_size, channels * heigth * width]);
+let [batch_size, channels, height, width] = x.dims();
+let x = x.reshape([batch_size, channels * height * width]);
let x = self.dropout.forward(x);
let x = self.fc1.forward(x);

@@ -24,7 +24,7 @@ pub fn run<B: Backend>() {
//
// mismatched types
// expected reference `&NamedTensor<B, (Batch, DModel, _)>`
-// found reference `&NamedTensor<B, (Batch, SeqLenght, DModel)>`
+// found reference `&NamedTensor<B, (Batch, SeqLength, DModel)>`
// let output = weights.matmul(&input);
let output = input.clone().matmul(weights.clone());
@@ -32,7 +32,7 @@
// Doesn't compile
//
// mismatched types
-// expected reference `&NamedTensor<B, (Batch, SeqLenght, DModel)>`
+// expected reference `&NamedTensor<B, (Batch, SeqLength, DModel)>`
// found reference `&NamedTensor<B, (Batch, DModel, DModel)>`
// let output = output.mul(&weights);

@@ -23,7 +23,7 @@ use std::sync::Arc;
pub struct TextClassificationBatcher<B: Backend> {
tokenizer: Arc<dyn Tokenizer>, // Tokenizer for converting text to token IDs
device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device)
-max_seq_lenght: usize, // Maximum sequence length for tokenized text
+max_seq_length: usize, // Maximum sequence length for tokenized text
}
/// Struct for training batch in text classification task
@@ -60,7 +60,7 @@ impl<B: Backend> Batcher<TextClassificationItem, TextClassificationTrainingBatch
let mask = generate_padding_mask(
self.tokenizer.pad_token(),
tokens_list,
-Some(self.max_seq_lenght),
+Some(self.max_seq_length),
&B::Device::default(),
);
@@ -90,7 +90,7 @@ impl<B: Backend> Batcher<String, TextClassificationInferenceBatch<B>>
let mask = generate_padding_mask(
self.tokenizer.pad_token(),
tokens_list,
-Some(self.max_seq_lenght),
+Some(self.max_seq_length),
&B::Device::default(),
);

@@ -9,7 +9,7 @@ use std::sync::Arc;
#[derive(new)]
pub struct TextGenerationBatcher {
tokenizer: Arc<dyn Tokenizer>,
-max_seq_lenght: usize,
+max_seq_length: usize,
}
#[derive(Debug, Clone, new)]
@@ -36,7 +36,7 @@ impl<B: Backend> Batcher<TextGenerationItem, TextGenerationBatch<B>> for TextGen
let mask = generate_padding_mask(
self.tokenizer.pad_token(),
tokens_list,
-Some(self.max_seq_lenght),
+Some(self.max_seq_length),
&B::Device::default(),
);