Fix for "Add Operator Settings #6650" (#6670)

* Fix for "Add Operator Settings #6650"

* Fix lint
This commit is contained in:
Steve Wood 2021-06-30 20:14:41 -04:00 committed by GitHub
parent 6819cc07c1
commit 6c08d4bfe0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 3 deletions

View File

@ -189,7 +189,8 @@ class Gradient(GradientBase):
elif isinstance(operator, ListOp):
grad_ops = [self.get_gradient(op, param) for op in operator.oplist]
if operator._combo_fn is None: # If using default
# pylint: disable=comparison-with-callable
if operator.combo_fn == ListOp.default_combo_fn: # If using default
return ListOp(oplist=grad_ops)
elif isinstance(operator, SummedOp):
return SummedOp(oplist=[grad for grad in grad_ops if grad != ~Zero @ One]).reduce()

View File

@ -14,7 +14,7 @@
from functools import reduce
from numbers import Number
from typing import Callable, Dict, Iterator, List, Optional, Set, Sequence, Union, cast
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Sequence, Union, cast
import numpy as np
from scipy.sparse import spmatrix
@ -122,6 +122,11 @@ class ListOp(OperatorBase):
"""
return self._oplist
@staticmethod
def default_combo_fn(x: Any) -> Any:
"""ListOp default combo function i.e. lambda x: x"""
return x
@property
def combo_fn(self) -> Callable:
"""The function defining how to combine ``oplist`` (or Numbers, or NumPy arrays) to
@ -132,7 +137,7 @@ class ListOp(OperatorBase):
The combination function.
"""
if self._combo_fn is None:
return lambda x: x
return ListOp.default_combo_fn
return self._combo_fn
@property