Fix RWKV backward on GPU (#23774)

Sylvain Gugger, 2023-05-26 08:33:17 -04:00 (committed by GitHub)
parent 8d28dba35d
commit 4d9b76a80f
1 changed file with 3 additions and 6 deletions

src/transformers/models/rwkv/modeling_rwkv.py

@@ -159,7 +159,7 @@ class RwkvLinearAttention(torch.autograd.Function):
     @staticmethod
     # g stands for grad
-    def backward(ctx, g_output):
+    def backward(ctx, g_output, g_state=None):
         input_dtype = ctx.input_dtype
 
         time_decay, time_first, key, value, output = ctx.saved_tensors
@@ -188,17 +188,14 @@ class RwkvLinearAttention(torch.autograd.Function):
             g_key,
             g_value,
         )
-        g_time_decay = torch.sum(g_time_decay, dim=0)
-        g_time_first = torch.sum(g_time_first, dim=0)
         return (
-            None,
-            None,
-            None,
             g_time_decay.to(input_dtype),
             g_time_first.to(input_dtype),
             g_key.to(input_dtype),
             g_value.to(input_dtype),
             None,
             None,
         )
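
The change restores the torch.autograd.Function contract: forward returns two outputs (output, state), so backward must accept two incoming gradients, and it must return exactly one gradient (or None) per forward input. Given the forward signature in the surrounding file, forward(ctx, time_decay, time_first, key, value, state=None, return_state=False), that means six entries rather than the nine returned before; the three leading Nones and the torch.sum reductions appear to be leftovers from the standalone RWKV kernel wrapper, whose forward also took the batch/time/channel sizes as explicit arguments. Below is a minimal, self-contained sketch of that contract; ToyLinearAttention and its placeholder arithmetic are invented for illustration and are not the real CUDA kernel.

import torch

class ToyLinearAttention(torch.autograd.Function):
    """Toy stand-in for RwkvLinearAttention; the forward math is made up."""

    @staticmethod
    def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False):
        output = key * value + time_decay + time_first  # placeholder computation
        ctx.save_for_backward(key, value)
        return output, state

    @staticmethod
    def backward(ctx, g_output, g_state=None):
        # Two forward outputs mean two incoming gradients; g_state is None
        # when the state output was a non-tensor or played no part in the loss.
        key, value = ctx.saved_tensors
        # Exactly one entry per forward input, in order:
        # time_decay, time_first, key, value, state, return_state.
        return (
            g_output.clone(),  # d(output)/d(time_decay) is 1 elementwise
            g_output.clone(),  # d(output)/d(time_first) is 1 elementwise
            g_output * value,  # d(output)/d(key)
            g_output * key,    # d(output)/d(value)
            None,              # state: no gradient in this toy
            None,              # return_state: non-differentiable flag
        )

time_decay = torch.randn(4, requires_grad=True)
time_first = torch.randn(4, requires_grad=True)
key = torch.randn(4, requires_grad=True)
value = torch.randn(4, requires_grad=True)

output, _ = ToyLinearAttention.apply(time_decay, time_first, key, value, None, False)
output.sum().backward()  # succeeds: six gradients match six forward inputs

With the pre-fix return tuple, autograd rejects the mismatch at the first GPU backward pass (a RuntimeError about returning an incorrect number of gradients, 9 where 6 are expected), and without the g_state parameter the call itself fails once autograd passes a gradient for the second output; together the two edits are what makes backward run at all.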