final state
This commit is contained in:
parent
513b933b60
commit
751c4dbdfd
|
@ -13,8 +13,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Gemma model."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
|
@ -54,7 +53,6 @@ from .configuration_gemma import GemmaConfig
|
|||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
"""PyTorch Gemma model."""
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -728,7 +726,6 @@ class GemmaPreTrainedModel(PreTrainedModel):
|
|||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
_CONFIG_FOR_DOC = "GemmaConfig"
|
||||
|
||||
|
||||
|
|
|
@ -54,6 +54,7 @@ from ...utils import (
|
|||
)
|
||||
from .configuration_llama import LlamaConfig
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
|
|
@ -349,17 +349,19 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
return node
|
||||
|
||||
def leave_SimpleStatementLine(self, original_node, updated_node):
|
||||
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
|
||||
self.all_imports.append(updated_node)
|
||||
return updated_node
|
||||
elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])):
|
||||
full_statement = self.python_module.code_for_node(updated_node.body[0].module)
|
||||
if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", full_statement):
|
||||
return cst.RemoveFromParent()
|
||||
self.all_imports.append(updated_node)
|
||||
return updated_node
|
||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
if m.matches(parent_node, m.Module()):
|
||||
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
|
||||
if parent_node not in self.all_imports:
|
||||
self.all_imports.append(updated_node)
|
||||
return updated_node
|
||||
elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])):
|
||||
full_statement = self.python_module.code_for_node(updated_node.body[0].module)
|
||||
if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", full_statement):
|
||||
return cst.RemoveFromParent()
|
||||
if parent_node not in self.all_imports:
|
||||
self.all_imports.append(updated_node)
|
||||
return updated_node
|
||||
self.global_scope_index += 100
|
||||
self.new_body[self.python_module.code_for_node(updated_node.body[0])] = {
|
||||
"insert_idx": self.global_scope_index,
|
||||
|
@ -451,7 +453,7 @@ def convert_file(diff_file, cst_transformers=None):
|
|||
if cst_transformers is None:
|
||||
cst_transformers = DiffConverterTransformer(module)
|
||||
new_mod = wrapper.visit(cst_transformers)
|
||||
ruffed_code = new_mod.code # run_ruff(new_mod.code, True)
|
||||
ruffed_code = run_ruff(new_mod.code, True)
|
||||
if len(ruffed_code) > 0:
|
||||
with open(diff_file.replace("diff_", "modeling_"), "w") as f:
|
||||
f.write(AUTO_GENERATED_MESSAGE + ruffed_code)
|
||||
|
|
Loading…
Reference in New Issue