Fix RecurrentGemma device_map (#30273)

* Switch to non-persistent buffer

* Fix device mismatch issue caused by the cache

* style
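
Both functional changes target loading RecurrentGemma with a `device_map`, where accelerate shards the layers across several devices. A minimal sketch of that setup (not part of this commit; the checkpoint name, prompt, and `device_map="auto"` split are assumptions for illustration, and it needs at least two GPUs to exercise the bug):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/recurrentgemma-2b")
model = AutoModelForCausalLM.from_pretrained(
    "google/recurrentgemma-2b",
    torch_dtype=torch.bfloat16,
    device_map="auto",  # lets accelerate place different layers on different GPUs
)

inputs = tokenizer("The key to a good soup is", return_tensors="pt").to(model.device)
# Before this fix, generate() could hit a device-mismatch RuntimeError once a layer's
# cache buffers ended up on a different GPU than that layer's inputs.
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))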
commit 7509a0ad98
parent 9459efb807
Author: Marc Sun
Date:   2024-04-18 11:52:10 +02:00 (committed by GitHub)
1 changed file with 8 additions and 4 deletions

src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py

@@ -252,7 +252,7 @@ class RecurrentGemmaSdpaAttention(nn.Module):
         to_shift = cache_position >= self.config.attention_window_size - 1
         indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size
 
-        k_out, v_out = self.key_states, self.value_states
+        k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
         k_out = k_out[:, :, indices]
         v_out = v_out[:, :, indices]
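
With a sharded model, the pre-allocated `self.key_states` / `self.value_states` buffers can live on a different GPU than the layer's incoming activations, and any op mixing the two raises a device-mismatch error; the `.to(key_states.device)` above aligns them first. A standalone sketch of that failure mode (illustrative only, requires two CUDA devices):

import torch

if torch.cuda.device_count() >= 2:
    cache = torch.zeros(4, device="cuda:0")   # buffer allocated when the module was built
    states = torch.ones(4, device="cuda:1")   # activation from a layer dispatched to cuda:1
    # cache * states                          # would raise: expected all tensors on the same device
    out = cache.to(states.device) * states    # explicit .to() mirrors the fix above
    print(out.device)                         # cuda:1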
@@ -376,7 +376,9 @@ class RecurrentGemmaRglru(nn.Module):
                 return hidden_states, hidden_states[:, 0].type(acc_dtype)
 
             else:
-                contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None]
+                contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None].to(
+                    recurrent_gate.device
+                )
                 contextualized_states += hidden_states.type(acc_dtype)
                 return contextualized_states.type(hidden_states.dtype), contextualized_states[:, -1]
@@ -387,7 +389,7 @@ class RecurrentGemmaRglru(nn.Module):
             contextualized_states = torch.zeros_like(hidden_states)
             for t in range(hidden_states.shape[1]):
-                recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states
+                recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states.to(recurrent_gate.device)
                 recurrent_states = recurrent_states + hidden_states[:, t].type(acc_dtype)
                 contextualized_states[:, t] = recurrent_states.type(hidden_states.dtype)
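
For context, the loop patched in the two hunks above is the sequential form of the RG-LRU recurrence, roughly h_t = a_t * h_(t-1) + x_t; the carried state can come from a cache that lives on another device, hence the `.to(recurrent_gate.device)`. A standalone sketch of the same scan (function name, tensor names, and shapes are assumptions):

import torch

def linear_scan(gates: torch.Tensor, inputs: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
    # gates, inputs: [batch, seq_len, width]; state: [batch, width], possibly on another device
    outputs = torch.zeros_like(inputs)
    for t in range(inputs.shape[1]):
        # Keep the carried state on the gates' device, mirroring the fix above.
        state = gates[:, t] * state.to(gates.device) + inputs[:, t]
        outputs[:, t] = state
    return outputs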
@@ -658,7 +660,9 @@ class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
         self.final_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
-        self.register_buffer("normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16))
+        self.register_buffer(
+            "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16), persistent=False
+        )
 
         # Initialize weights and apply final processing
         self.post_init()
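
Registering `normalizer` with `persistent=False` keeps it out of `state_dict()`, so it is neither written to checkpoints nor expected when one is loaded; the buffer is simply rebuilt at init and moved along with the module, which sidesteps the problem of a key that existing checkpoints never contained when the model is materialized under `device_map` loading. A minimal standalone sketch of the behaviour (the `Demo` module and its value are illustrative, not the model's):

import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent: excluded from state_dict(), recreated on every init.
        self.register_buffer("normalizer", torch.tensor(256**0.5), persistent=False)

print("normalizer" in Demo().state_dict())  # False: nothing to save, load, or dispatch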