mirror of https://github.com/tracel-ai/burn.git
fix: dataset + optim
This commit is contained in:
parent
f2f4fa8a92
commit
31d512ed8f
|
@ -130,19 +130,30 @@ fn download(
|
|||
command.arg(split);
|
||||
}
|
||||
|
||||
let mut extracted_raw = Vec::new();
|
||||
let mut extracted_images = Vec::new();
|
||||
|
||||
for extractor in extractors {
|
||||
match extractor {
|
||||
Extractor::Raw(field) => {
|
||||
command.arg("--extract-raw");
|
||||
command.arg(field);
|
||||
}
|
||||
Extractor::Image(field) => {
|
||||
command.arg("--extract-image");
|
||||
command.arg(field);
|
||||
}
|
||||
Extractor::Raw(field) => extracted_raw.push(field),
|
||||
Extractor::Image(field) => extracted_images.push(field),
|
||||
};
|
||||
}
|
||||
|
||||
if !extracted_raw.is_empty() {
|
||||
command.arg("--extract-raw");
|
||||
for field in extracted_raw {
|
||||
command.arg(field);
|
||||
}
|
||||
}
|
||||
|
||||
if !extracted_images.is_empty() {
|
||||
command.arg("--extract-image");
|
||||
for field in extracted_images {
|
||||
command.arg(field);
|
||||
}
|
||||
}
|
||||
|
||||
if !config.is_empty() {
|
||||
command.arg("--config");
|
||||
for config in config {
|
||||
|
|
|
@ -47,11 +47,11 @@ impl<const D: usize, B: Backend> BackwardRecordedOps<B::TensorPrimitive<D>>
|
|||
let indexes: Vec<_> = grad.shape().dims.iter().map(|v| 0..*v).collect();
|
||||
let indexes: [std::ops::Range<usize>; D] = indexes.try_into().unwrap();
|
||||
|
||||
for (i, node) in self.nodes.iter().enumerate() {
|
||||
self.nodes.iter().enumerate().for_each(|(i, node)| {
|
||||
let mut indexes = indexes.clone();
|
||||
indexes[self.dim] = i..i + 1;
|
||||
node.state.update_grad(grad.index(indexes));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
|
||||
|
|
|
@ -360,6 +360,12 @@ where
|
|||
Self::new(tensor)
|
||||
}
|
||||
|
||||
/// Create a tensor of the given shape where each element is zero.
|
||||
pub fn zeros_device(shape: Shape<D>, device: B::Device) -> Self {
|
||||
let tensor = B::zeros(shape, device);
|
||||
Self::new(tensor)
|
||||
}
|
||||
|
||||
/// Create a tensor of the given shape where each element is one.
|
||||
pub fn ones(shape: Shape<D>) -> Self {
|
||||
let tensor = B::ones(shape, B::Device::default());
|
||||
|
|
|
@ -54,7 +54,7 @@ impl<B: ADBackend> Optimizer for Sgd<B> {
|
|||
tensor: &mut Tensor<B, D>,
|
||||
grads: &Gradients,
|
||||
) {
|
||||
let grad = tensor.grad(grads).unwrap();
|
||||
if let Some(grad) = tensor.grad(grads) {
|
||||
let grad = match &mut self.weight_decay {
|
||||
Some(weight_decay) => weight_decay.transform(id, grad),
|
||||
None => grad,
|
||||
|
@ -67,6 +67,7 @@ impl<B: ADBackend> Optimizer for Sgd<B> {
|
|||
let delta = grad.mul_scalar(self.learning_rate);
|
||||
tensor.update(tensor.inner() - delta);
|
||||
}
|
||||
}
|
||||
|
||||
fn register_param_state<const D: usize>(&self, id: &ParamId, state: &mut StateNamed<B::Elem>) {
|
||||
if let Some(momentum) = &self.momentum {
|
||||
|
|
Loading…
Reference in New Issue