Fix a shape annotation and typos in `mamba` slow forward (#30691)
* fix typos and one shape comment
* fix `intermediade` typo in jamba
commit 76e05301c3
parent e6708709cb
src/transformers/models/jamba/modeling_jamba.py

@@ -962,15 +962,15 @@ class JambaMambaMixer(nn.Module):
         # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
         A = -torch.exp(self.A_log.float())  # [intermediate_size, ssm_state_size]
         discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])  # [batch, intermediate_size, seq_len, ssm_state_size]
-        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()  # [batch, intermediade_size, seq_len, ssm_state_size]
+        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()  # [batch, intermediate_size, seq_len, ssm_state_size]
         deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
         scan_outputs = []
         for i in range(seq_len):
-            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediade_size, ssm_state]
-            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
+            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediate_size, ssm_state]
+            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
             scan_outputs.append(scan_output[:, :, 0])
-        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediade_size, seq_len]
+        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
         scan_output = scan_output + (hidden_states * self.D[None, :, None])
         scan_output = (scan_output * self.act(gate))

@@ -978,7 +978,7 @@ class JambaMambaMixer(nn.Module):
             cache_params.ssm_states[self.layer_idx] = ssm_state

         # 4. Final linear projection
-        contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
+        contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
         return contextualized_states
         # fmt: on

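Why the corrected comment reads `[batch, intermediate_size, seq_len]`: each loop step produces a `[batch, intermediate_size, 1]` matmul result, and `torch.stack(..., dim=-1)` over the `seq_len` squeezed slices puts `seq_len` last. A minimal sketch verifying just that shape logic; the sizes are made up for illustration, not taken from any model config:

```python
import torch

# Hypothetical sizes, chosen only to make the shapes visible.
batch, seq_len, intermediate_size, ssm_state_size = 2, 5, 8, 4

ssm_state = torch.zeros(batch, intermediate_size, ssm_state_size)
C = torch.randn(batch, seq_len, ssm_state_size)

scan_outputs = []
for i in range(seq_len):
    # [batch, intermediate_size, ssm_state_size] @ [batch, ssm_state_size, 1]
    scan_output = torch.matmul(ssm_state, C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
    scan_outputs.append(scan_output[:, :, 0])  # [batch, intermediate_size]

scan_output = torch.stack(scan_outputs, dim=-1)
assert scan_output.shape == (batch, intermediate_size, seq_len)  # seq_len is last
```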
src/transformers/models/mamba/modeling_mamba.py

@@ -279,16 +279,16 @@ class MambaMixer(nn.Module):
         # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
         A = -torch.exp(self.A_log.float())  # [intermediate_size, ssm_state_size]
         discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])  # [batch, intermediate_size, seq_len, ssm_state_size]
-        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()  # [batch, intermediade_size, seq_len, ssm_state_size]
+        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()  # [batch, intermediate_size, seq_len, ssm_state_size]
         deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
         scan_outputs = []
         for i in range(seq_len):
-            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediade_size, ssm_state]
-            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
+            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediate_size, ssm_state]
+            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
             scan_outputs.append(scan_output[:, :, 0])
-        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, seq_len, intermediade_size]
+        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
         scan_output = scan_output + (hidden_states * self.D[None, :, None])
         scan_output = (scan_output * self.act(gate))

@@ -296,7 +296,7 @@ class MambaMixer(nn.Module):
             cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

         # 4. Final linear projection
-        contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
+        contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
         return contextualized_states
         # fmt: on

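For reference, a self-contained sketch of the slow selective-scan path that these hunks annotate, written with the corrected shape comments. The function name, the zero initial state, and the random inputs are assumptions for illustration; this is not the `transformers` implementation itself:

```python
import torch


def slow_selective_scan_sketch(hidden_states, discrete_time_step, A, B, C, D):
    """Hypothetical standalone recurrence, for checking the shape comments.

    hidden_states:      [batch, intermediate_size, seq_len]
    discrete_time_step: [batch, intermediate_size, seq_len]
    A:                  [intermediate_size, ssm_state_size]
    B, C:               [batch, seq_len, ssm_state_size]
    D:                  [intermediate_size]
    """
    batch, intermediate_size, seq_len = hidden_states.shape
    ssm_state_size = A.shape[-1]

    # 3.b. Discretization, same broadcasting as the code above.
    discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])  # [batch, intermediate_size, seq_len, ssm_state_size]
    discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()  # [batch, intermediate_size, seq_len, ssm_state_size]
    deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

    # 3.c. Recurrence y ← SSM(A, B, C)(x), one timestep at a time.
    ssm_state = torch.zeros(batch, intermediate_size, ssm_state_size)  # assumed zero initial state
    scan_outputs = []
    for i in range(seq_len):
        ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediate_size, ssm_state_size]
        scan_output = torch.matmul(ssm_state, C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
        scan_outputs.append(scan_output[:, :, 0])
    scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
    return scan_output + hidden_states * D[None, :, None]


# Illustrative sizes only.
b, d, L, N = 2, 8, 5, 4
out = slow_selective_scan_sketch(
    torch.randn(b, d, L),
    torch.nn.functional.softplus(torch.randn(b, d, L)),  # positive step sizes
    -torch.rand(d, N),  # negative A, as with -exp(A_log)
    torch.randn(b, L, N),
    torch.randn(b, L, N),
    torch.randn(d),
)
assert out.shape == (b, d, L)  # [batch, intermediate_size, seq_len]
```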