mirror of https://github.com/tracel-ai/burn.git
fix: dataset + optim
This commit is contained in:
parent
f2f4fa8a92
commit
31d512ed8f
|
@ -130,19 +130,30 @@ fn download(
|
||||||
command.arg(split);
|
command.arg(split);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut extracted_raw = Vec::new();
|
||||||
|
let mut extracted_images = Vec::new();
|
||||||
|
|
||||||
for extractor in extractors {
|
for extractor in extractors {
|
||||||
match extractor {
|
match extractor {
|
||||||
Extractor::Raw(field) => {
|
Extractor::Raw(field) => extracted_raw.push(field),
|
||||||
command.arg("--extract-raw");
|
Extractor::Image(field) => extracted_images.push(field),
|
||||||
command.arg(field);
|
|
||||||
}
|
|
||||||
Extractor::Image(field) => {
|
|
||||||
command.arg("--extract-image");
|
|
||||||
command.arg(field);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !extracted_raw.is_empty() {
|
||||||
|
command.arg("--extract-raw");
|
||||||
|
for field in extracted_raw {
|
||||||
|
command.arg(field);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !extracted_images.is_empty() {
|
||||||
|
command.arg("--extract-image");
|
||||||
|
for field in extracted_images {
|
||||||
|
command.arg(field);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !config.is_empty() {
|
if !config.is_empty() {
|
||||||
command.arg("--config");
|
command.arg("--config");
|
||||||
for config in config {
|
for config in config {
|
||||||
|
|
|
@ -47,11 +47,11 @@ impl<const D: usize, B: Backend> BackwardRecordedOps<B::TensorPrimitive<D>>
|
||||||
let indexes: Vec<_> = grad.shape().dims.iter().map(|v| 0..*v).collect();
|
let indexes: Vec<_> = grad.shape().dims.iter().map(|v| 0..*v).collect();
|
||||||
let indexes: [std::ops::Range<usize>; D] = indexes.try_into().unwrap();
|
let indexes: [std::ops::Range<usize>; D] = indexes.try_into().unwrap();
|
||||||
|
|
||||||
for (i, node) in self.nodes.iter().enumerate() {
|
self.nodes.iter().enumerate().for_each(|(i, node)| {
|
||||||
let mut indexes = indexes.clone();
|
let mut indexes = indexes.clone();
|
||||||
indexes[self.dim] = i..i + 1;
|
indexes[self.dim] = i..i + 1;
|
||||||
node.state.update_grad(grad.index(indexes));
|
node.state.update_grad(grad.index(indexes));
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
|
fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
|
||||||
|
|
|
@ -360,6 +360,12 @@ where
|
||||||
Self::new(tensor)
|
Self::new(tensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a tensor of the given shape where each element is zero.
|
||||||
|
pub fn zeros_device(shape: Shape<D>, device: B::Device) -> Self {
|
||||||
|
let tensor = B::zeros(shape, device);
|
||||||
|
Self::new(tensor)
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a tensor of the given shape where each element is one.
|
/// Create a tensor of the given shape where each element is one.
|
||||||
pub fn ones(shape: Shape<D>) -> Self {
|
pub fn ones(shape: Shape<D>) -> Self {
|
||||||
let tensor = B::ones(shape, B::Device::default());
|
let tensor = B::ones(shape, B::Device::default());
|
||||||
|
|
|
@ -54,7 +54,7 @@ impl<B: ADBackend> Optimizer for Sgd<B> {
|
||||||
tensor: &mut Tensor<B, D>,
|
tensor: &mut Tensor<B, D>,
|
||||||
grads: &Gradients,
|
grads: &Gradients,
|
||||||
) {
|
) {
|
||||||
let grad = tensor.grad(grads).unwrap();
|
if let Some(grad) = tensor.grad(grads) {
|
||||||
let grad = match &mut self.weight_decay {
|
let grad = match &mut self.weight_decay {
|
||||||
Some(weight_decay) => weight_decay.transform(id, grad),
|
Some(weight_decay) => weight_decay.transform(id, grad),
|
||||||
None => grad,
|
None => grad,
|
||||||
|
@ -67,6 +67,7 @@ impl<B: ADBackend> Optimizer for Sgd<B> {
|
||||||
let delta = grad.mul_scalar(self.learning_rate);
|
let delta = grad.mul_scalar(self.learning_rate);
|
||||||
tensor.update(tensor.inner() - delta);
|
tensor.update(tensor.inner() - delta);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn register_param_state<const D: usize>(&self, id: &ParamId, state: &mut StateNamed<B::Elem>) {
|
fn register_param_state<const D: usize>(&self, id: &ParamId, state: &mut StateNamed<B::Elem>) {
|
||||||
if let Some(momentum) = &self.momentum {
|
if let Some(momentum) = &self.momentum {
|
||||||
|
|
Loading…
Reference in New Issue