fixups
This commit is contained in:
parent
0faa82da98
commit
1836a758f8
|
@ -1,5 +1,4 @@
|
||||||
|
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
|
||||||
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
|
|
||||||
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
|
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
|
||||||
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
|
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
|
||||||
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
|
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
|
||||||
|
@ -48,7 +47,7 @@
|
||||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
||||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
||||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
|
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
|
||||||
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
|
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
|
||||||
# This file was automatically generated from <path_to_diff_file.py>.
|
# This file was automatically generated from <path_to_diff_file.py>.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the diff.
|
# the file from the diff.
|
||||||
|
@ -72,7 +71,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,15 +23,16 @@ import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
LlamaForSequenceClassification,
|
LlamaForSequenceClassification,
|
||||||
LlamaModel,
|
|
||||||
LlamaForTokenClassification,
|
LlamaForTokenClassification,
|
||||||
|
LlamaModel,
|
||||||
apply_rotary_pos_emb,
|
apply_rotary_pos_emb,
|
||||||
repeat_kv,
|
repeat_kv,
|
||||||
)
|
)
|
||||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache
|
from ...cache_utils import Cache
|
||||||
from ...modeling_outputs import CausalLMOutputWithPast
|
from ...modeling_outputs import CausalLMOutputWithPast
|
||||||
|
@ -109,6 +110,7 @@ class GemmaConfig(PretrainedConfig):
|
||||||
>>> # Accessing the model configuration
|
>>> # Accessing the model configuration
|
||||||
>>> configuration = model.config
|
>>> configuration = model.config
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
model_type = "gemma"
|
model_type = "gemma"
|
||||||
keys_to_ignore_at_inference = ["past_key_values"]
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
@ -161,6 +163,7 @@ class GemmaConfig(PretrainedConfig):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Example where we only want to overwrite the defaults of an init?
|
# Example where we only want to overwrite the defaults of an init?
|
||||||
class GemmaConfig(LlamaConfig):
|
class GemmaConfig(LlamaConfig):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -188,6 +191,7 @@ class GemmaConfig(LlamaConfig):
|
||||||
):
|
):
|
||||||
super().__init__(self)
|
super().__init__(self)
|
||||||
|
|
||||||
|
|
||||||
class GemmaRMSNorm(nn.Module):
|
class GemmaRMSNorm(nn.Module):
|
||||||
def __init__(self, dim: int, eps: float = 1e-6):
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -407,6 +411,7 @@ class GemmaModel(LlamaModel):
|
||||||
cache_position,
|
cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Example where we ony modify the docstring and call super
|
# Example where we ony modify the docstring and call super
|
||||||
class GemmaForCausalLM(LlamaForCausalLM):
|
class GemmaForCausalLM(LlamaForCausalLM):
|
||||||
def forward(
|
def forward(
|
||||||
|
@ -459,12 +464,13 @@ class GemmaForCausalLM(LlamaForCausalLM):
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
return_dict,
|
return_dict,
|
||||||
cache_position,
|
cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class GemmaForSequenceClassification(LlamaForSequenceClassification):
|
class GemmaForSequenceClassification(LlamaForSequenceClassification):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class GemmaForTokenClassification(LlamaForTokenClassification):
|
class GemmaForTokenClassification(LlamaForTokenClassification):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
|
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
|
||||||
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
|
|
||||||
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
|
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
|
||||||
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
|
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
|
||||||
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
|
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
|
||||||
|
@ -48,7 +47,7 @@
|
||||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
||||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
||||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
|
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
|
||||||
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
|
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
|
||||||
# This file was automatically generated from <path_to_diff_file.py>.
|
# This file was automatically generated from <path_to_diff_file.py>.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the diff.
|
# the file from the diff.
|
||||||
|
@ -137,6 +136,7 @@ def _get_unpad_data(attention_mask):
|
||||||
max_seqlen_in_batch,
|
max_seqlen_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class GemmaRMSNorm(nn.Module):
|
class GemmaRMSNorm(nn.Module):
|
||||||
def __init__(self, dim: int, eps: float = 1e-6):
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
|
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
|
||||||
# ░░▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒░░
|
|
||||||
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
|
# ░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒
|
||||||
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
|
# ░░██░░▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓░░██
|
||||||
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
|
# ░░██░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░██
|
||||||
|
@ -48,7 +47,7 @@
|
||||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██████▒▒░░░░░░██████░░░░░░░░░░██████████████████████████████████████░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
||||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░██
|
||||||
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
|
# ░░██░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░▒▒░░░░░░▒▒░░░░░░░░██
|
||||||
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
|
# ░░████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████░░
|
||||||
# This file was automatically generated from <path_to_diff_file.py>.
|
# This file was automatically generated from <path_to_diff_file.py>.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the diff.
|
# the file from the diff.
|
||||||
|
|
|
@ -74,6 +74,7 @@ AUTO_GENERATED_MESSAGE = """
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_module_source_from_name(module_name: str) -> str:
|
def get_module_source_from_name(module_name: str) -> str:
|
||||||
spec = importlib.util.find_spec(module_name)
|
spec = importlib.util.find_spec(module_name)
|
||||||
if spec is None or spec.origin is None:
|
if spec is None or spec.origin is None:
|
||||||
|
@ -236,7 +237,17 @@ def find_classes_in_file(module, old_id="llama", new_id="gemma"):
|
||||||
wrapper.visit(class_finder)
|
wrapper.visit(class_finder)
|
||||||
return class_finder
|
return class_finder
|
||||||
|
|
||||||
DOCSTRING_NODE = m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString(value=m.MatchIfTrue(lambda value: re.search(r'\"\"\"[\s\S]*\"\"\"',value) is not None)))])
|
|
||||||
|
DOCSTRING_NODE = m.SimpleStatementLine(
|
||||||
|
body=[
|
||||||
|
m.Expr(
|
||||||
|
value=m.SimpleString(
|
||||||
|
value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SuperTransformer(cst.CSTTransformer):
|
class SuperTransformer(cst.CSTTransformer):
|
||||||
METADATA_DEPENDENCIES = (ParentNodeProvider,)
|
METADATA_DEPENDENCIES = (ParentNodeProvider,)
|
||||||
|
@ -347,15 +358,15 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
|
||||||
|
|
||||||
# TODO here is where we merge stuff from super. We can choose to merge the docstring as well!
|
# TODO here is where we merge stuff from super. We can choose to merge the docstring as well!
|
||||||
# We could also check the docstring here
|
# We could also check the docstring here
|
||||||
original_methods = {f.name.value if hasattr(f,"name") else f: f for f in original_node.body.body }
|
original_methods = {f.name.value if hasattr(f, "name") else f: f for f in original_node.body.body}
|
||||||
|
|
||||||
# Copy methods from original node to replacement node, preserving decorators
|
# Copy methods from original node to replacement node, preserving decorators
|
||||||
updated_methods = {f.name.value if hasattr(f,"name") else f: f for f in updated_node.body.body }
|
updated_methods = {f.name.value if hasattr(f, "name") else f: f for f in updated_node.body.body}
|
||||||
end_meth = []
|
end_meth = []
|
||||||
for name, func in original_methods.items():
|
for name, func in original_methods.items():
|
||||||
if name in updated_methods:
|
if name in updated_methods:
|
||||||
# Replace the method in the replacement class, preserving decorators
|
# Replace the method in the replacement class, preserving decorators
|
||||||
func = func.with_changes(body=updated_methods[name].body, params = updated_methods[name].params )
|
func = func.with_changes(body=updated_methods[name].body, params=updated_methods[name].params)
|
||||||
end_meth.append(func)
|
end_meth.append(func)
|
||||||
|
|
||||||
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
|
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
|
||||||
|
|
Loading…
Reference in New Issue