Fix a shape annotation and typos in `mamba` slow forward (#30691)

* fix typos and one shape comment

* fix `intermediade` typo in jamba
Authored by Anton Vlasjuk on 2024-05-20 13:55:57 +02:00, committed by GitHub
parent e6708709cb
commit 76e05301c3
2 changed files with 10 additions and 10 deletions


@@ -962,15 +962,15 @@ class JambaMambaMixer(nn.Module):
         # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
         A = -torch.exp(self.A_log.float())  # [intermediate_size, ssm_state_size]
         discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])  # [batch, intermediate_size, seq_len, ssm_state_size]
-        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()  # [batch, intermediade_size, seq_len, ssm_state_size]
+        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()  # [batch, intermediate_size, seq_len, ssm_state_size]
         deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
         scan_outputs = []
         for i in range(seq_len):
-            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediade_size, ssm_state]
-            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
+            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediate_size, ssm_state]
+            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
             scan_outputs.append(scan_output[:, :, 0])
-        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediade_size, seq_len]
+        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
         scan_output = scan_output + (hidden_states * self.D[None, :, None])
         scan_output = (scan_output * self.act(gate))
@@ -978,7 +978,7 @@ class JambaMambaMixer(nn.Module):
             cache_params.ssm_states[self.layer_idx] = ssm_state
         # 4. Final linear projection
         contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
         return contextualized_states
     # fmt: on
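For reference, here is a minimal standalone shape check of the discretization step touched above. It uses dummy tensors and made-up sizes rather than the actual mixer module, and only illustrates why the corrected comments read [batch, intermediate_size, seq_len, ssm_state_size] for both discrete_A and discrete_B.

# Shape sketch of step 3.b with dummy tensors (illustrative sizes, not the real module).
import torch

batch, seq_len, intermediate_size, ssm_state_size = 2, 5, 8, 4

discrete_time_step = torch.rand(batch, intermediate_size, seq_len)  # stands in for softplus(dt_proj(...)), channels-first
A = -torch.exp(torch.rand(intermediate_size, ssm_state_size))       # stands in for -exp(self.A_log)
B = torch.rand(batch, seq_len, ssm_state_size)                      # input-dependent B

discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])
discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()

# Both broadcasts expand to [batch, intermediate_size, seq_len, ssm_state_size],
# matching the fixed comments.
assert discrete_A.shape == (batch, intermediate_size, seq_len, ssm_state_size)
assert discrete_B.shape == (batch, intermediate_size, seq_len, ssm_state_size)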


@@ -279,16 +279,16 @@ class MambaMixer(nn.Module):
         # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
         A = -torch.exp(self.A_log.float())  # [intermediate_size, ssm_state_size]
         discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])  # [batch, intermediate_size, seq_len, ssm_state_size]
-        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()  # [batch, intermediade_size, seq_len, ssm_state_size]
+        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()  # [batch, intermediate_size, seq_len, ssm_state_size]
         deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
         scan_outputs = []
         for i in range(seq_len):
-            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediade_size, ssm_state]
-            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
+            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediate_size, ssm_state]
+            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
             scan_outputs.append(scan_output[:, :, 0])
-        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, seq_len, intermediade_size]
+        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
         scan_output = scan_output + (hidden_states * self.D[None, :, None])
         scan_output = (scan_output * self.act(gate))
@@ -296,7 +296,7 @@ class MambaMixer(nn.Module):
             cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
         # 4. Final linear projection
         contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
         return contextualized_states
     # fmt: on
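Likewise, a small self-contained sketch of the slow scan recurrence and final projection (steps 3.c and 4), again with dummy tensors; a plain nn.Linear and nn.SiLU stand in for the module's out_proj and act. It shows that stacking the per-step outputs along the last dim yields [batch, intermediate_size, seq_len], matching the corrected shape annotation, and that the transpose before the projection is what produces [batch, seq_len, hidden_size].

# Shape sketch of steps 3.c and 4 with dummy tensors (illustrative sizes, not the real module).
import torch
import torch.nn as nn

batch, seq_len, intermediate_size, ssm_state_size, hidden_size = 2, 5, 8, 4, 6
dtype = torch.float32

discrete_A = torch.rand(batch, intermediate_size, seq_len, ssm_state_size)
deltaB_u = torch.rand(batch, intermediate_size, seq_len, ssm_state_size)
C = torch.rand(batch, seq_len, ssm_state_size)
hidden_states = torch.rand(batch, intermediate_size, seq_len)
gate = torch.rand(batch, intermediate_size, seq_len)
D = torch.rand(intermediate_size)
ssm_state = torch.zeros(batch, intermediate_size, ssm_state_size)
out_proj = nn.Linear(intermediate_size, hidden_size, bias=False)  # stands in for self.out_proj
act = nn.SiLU()                                                   # stands in for self.act

scan_outputs = []
for i in range(seq_len):
    ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediate_size, ssm_state_size]
    scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
    scan_outputs.append(scan_output[:, :, 0])                                  # [batch, intermediate_size]

# Stacking per-step [batch, intermediate_size] slices along the last dim gives
# [batch, intermediate_size, seq_len] -- the corrected annotation.
scan_output = torch.stack(scan_outputs, dim=-1)
assert scan_output.shape == (batch, intermediate_size, seq_len)

scan_output = scan_output + (hidden_states * D[None, :, None])
scan_output = scan_output * act(gate)

# The transpose moves seq_len ahead of the channel dim, so the projection maps
# [batch, seq_len, intermediate_size] -> [batch, seq_len, hidden_size].
contextualized_states = out_proj(scan_output.transpose(1, 2))
assert contextualized_states.shape == (batch, seq_len, hidden_size)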