Fix RecurrentGemma device_map (#30273)
* Switch to non persistant buffer * fix device mismatch issue due to cache * style
This commit is contained in:
parent
9459efb807
commit
7509a0ad98
|
@ -252,7 +252,7 @@ class RecurrentGemmaSdpaAttention(nn.Module):
|
|||
to_shift = cache_position >= self.config.attention_window_size - 1
|
||||
indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size
|
||||
|
||||
k_out, v_out = self.key_states, self.value_states
|
||||
k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
|
||||
k_out = k_out[:, :, indices]
|
||||
v_out = v_out[:, :, indices]
|
||||
|
||||
|
@ -376,7 +376,9 @@ class RecurrentGemmaRglru(nn.Module):
|
|||
return hidden_states, hidden_states[:, 0].type(acc_dtype)
|
||||
|
||||
else:
|
||||
contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None]
|
||||
contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None].to(
|
||||
recurrent_gate.device
|
||||
)
|
||||
contextualized_states += hidden_states.type(acc_dtype)
|
||||
return contextualized_states.type(hidden_states.dtype), contextualized_states[:, -1]
|
||||
|
||||
|
@ -387,7 +389,7 @@ class RecurrentGemmaRglru(nn.Module):
|
|||
|
||||
contextualized_states = torch.zeros_like(hidden_states)
|
||||
for t in range(hidden_states.shape[1]):
|
||||
recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states
|
||||
recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states.to(recurrent_gate.device)
|
||||
recurrent_states = recurrent_states + hidden_states[:, t].type(acc_dtype)
|
||||
contextualized_states[:, t] = recurrent_states.type(hidden_states.dtype)
|
||||
|
||||
|
@ -658,7 +660,9 @@ class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
|
|||
self.final_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.register_buffer("normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16))
|
||||
self.register_buffer(
|
||||
"normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16), persistent=False
|
||||
)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
|
|
Loading…
Reference in New Issue