fix: dataset + optim

nathaniel 2022-10-10 11:09:09 -04:00
parent f2f4fa8a92
commit 31d512ed8f
4 changed files with 39 additions and 21 deletions

@@ -130,19 +130,30 @@ fn download(
         command.arg(split);
     }
+    let mut extracted_raw = Vec::new();
+    let mut extracted_images = Vec::new();
+
     for extractor in extractors {
         match extractor {
-            Extractor::Raw(field) => {
-                command.arg("--extract-raw");
-                command.arg(field);
-            }
-            Extractor::Image(field) => {
-                command.arg("--extract-image");
-                command.arg(field);
-            }
+            Extractor::Raw(field) => extracted_raw.push(field),
+            Extractor::Image(field) => extracted_images.push(field),
         };
     }
+
+    if !extracted_raw.is_empty() {
+        command.arg("--extract-raw");
+        for field in extracted_raw {
+            command.arg(field);
+        }
+    }
+
+    if !extracted_images.is_empty() {
+        command.arg("--extract-image");
+        for field in extracted_images {
+            command.arg(field);
+        }
+    }
+
     if !config.is_empty() {
         command.arg("--config");
         for config in config {
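
The dataset change above collects the extractor fields first and then passes each flag once, followed by all of its fields, instead of repeating the flag per field (presumably because the downloader script parses each flag as taking multiple values). A minimal standalone sketch of that flag-batching pattern with std::process::Command; the script name and field names are hypothetical:

    use std::process::Command;

    fn main() {
        let extracted_raw = vec!["image_bytes", "label"]; // hypothetical dataset fields
        let mut command = Command::new("python3");
        command.arg("download.py"); // hypothetical downloader script

        // Emit the flag once, then every collected field after it.
        if !extracted_raw.is_empty() {
            command.arg("--extract-raw");
            for field in extracted_raw {
                command.arg(field);
            }
        }

        // Debug-print the program and arguments without spawning anything:
        // "python3" "download.py" "--extract-raw" "image_bytes" "label"
        println!("{:?}", command);
    }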

@@ -47,11 +47,11 @@ impl<const D: usize, B: Backend> BackwardRecordedOps<B::TensorPrimitive<D>>
         let indexes: Vec<_> = grad.shape().dims.iter().map(|v| 0..*v).collect();
         let indexes: [std::ops::Range<usize>; D] = indexes.try_into().unwrap();
 
-        for (i, node) in self.nodes.iter().enumerate() {
+        self.nodes.iter().enumerate().for_each(|(i, node)| {
             let mut indexes = indexes.clone();
             indexes[self.dim] = i..i + 1;
             node.state.update_grad(grad.index(indexes));
-        }
+        });
     }
 
     fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
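
The autodiff change is behavior-preserving: the indexed for loop becomes an iterator for_each, still narrowing the concat dimension to a one-wide slice per parent node. A standalone sketch of that slicing logic over plain ranges (the shape and dim values are illustrative stand-ins for burn's tensors):

    fn main() {
        let dims = [3usize, 4];
        let dim = 0; // dimension the tensors were concatenated along
        let indexes: Vec<_> = dims.iter().map(|v| 0..*v).collect();
        let indexes: [std::ops::Range<usize>; 2] = indexes.try_into().unwrap();

        // One gradient slice per concatenated node, as in the diff above.
        (0..dims[dim]).for_each(|i| {
            let mut indexes = indexes.clone();
            indexes[dim] = i..i + 1;
            println!("node {i} receives grad slice {indexes:?}");
        });
    }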

@@ -360,6 +360,12 @@ where
         Self::new(tensor)
     }
 
+    /// Create a tensor of the given shape where each element is zero.
+    pub fn zeros_device(shape: Shape<D>, device: B::Device) -> Self {
+        let tensor = B::zeros(shape, device);
+        Self::new(tensor)
+    }
+
     /// Create a tensor of the given shape where each element is one.
     pub fn ones(shape: Shape<D>) -> Self {
         let tensor = B::ones(shape, B::Device::default());
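
The new zeros_device mirrors zeros but takes an explicit device instead of falling back to B::Device::default(). A self-contained sketch of that constructor pattern; Device and Tensor here are stand-ins, not burn's types:

    #[derive(Clone, Copy, Debug, Default)]
    struct Device(usize); // stand-in device id (e.g. 0 = CPU)

    struct Tensor<const D: usize> {
        shape: [usize; D],
        device: Device,
    }

    impl<const D: usize> Tensor<D> {
        // Allocate on the default device, like Tensor::zeros.
        fn zeros(shape: [usize; D]) -> Self {
            Self::zeros_device(shape, Device::default())
        }

        // Allocate on an explicit device, like the new zeros_device.
        fn zeros_device(shape: [usize; D], device: Device) -> Self {
            Tensor { shape, device }
        }
    }

    fn main() {
        let a = Tensor::zeros([2, 3]);
        let b = Tensor::zeros_device([2, 3], Device(1));
        println!("{:?} on {:?}, {:?} on {:?}", a.shape, a.device, b.shape, b.device);
    }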

@@ -54,7 +54,7 @@ impl<B: ADBackend> Optimizer for Sgd<B> {
         tensor: &mut Tensor<B, D>,
         grads: &Gradients,
     ) {
-        let grad = tensor.grad(grads).unwrap();
+        if let Some(grad) = tensor.grad(grads) {
         let grad = match &mut self.weight_decay {
             Some(weight_decay) => weight_decay.transform(id, grad),
             None => grad,
@@ -67,6 +67,7 @@ impl<B: ADBackend> Optimizer for Sgd<B> {
         let delta = grad.mul_scalar(self.learning_rate);
         tensor.update(tensor.inner() - delta);
         }
+    }
 
     fn register_param_state<const D: usize>(&self, id: &ParamId, state: &mut StateNamed<B::Elem>) {
         if let Some(momentum) = &self.momentum {
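
The optimizer fix swaps a panicking unwrap for an if let, so a parameter that received no gradient (for example one that was unused in the forward pass) is skipped instead of crashing the update step. A standalone sketch of the guard with scalar stand-ins for burn's Tensor and Gradients:

    fn step(param: &mut f32, grad: Option<f32>, learning_rate: f32) {
        // Before the fix: let grad = grad.unwrap(); // panics when no gradient exists
        if let Some(grad) = grad {
            *param -= grad * learning_rate;
        }
        // No gradient: leave the parameter untouched.
    }

    fn main() {
        let mut w = 1.0_f32;
        step(&mut w, Some(0.5), 0.1); // updated: w = 0.95
        step(&mut w, None, 0.1); // skipped, no panic
        println!("{w}");
    }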