Fix RWKV backward on GPU (#23774)
This commit is contained in:
parent
8d28dba35d
commit
4d9b76a80f
|
@ -159,7 +159,7 @@ class RwkvLinearAttention(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# g stands for grad
|
# g stands for grad
|
||||||
def backward(ctx, g_output):
|
def backward(ctx, g_output, g_state=None):
|
||||||
input_dtype = ctx.input_dtype
|
input_dtype = ctx.input_dtype
|
||||||
|
|
||||||
time_decay, time_first, key, value, output = ctx.saved_tensors
|
time_decay, time_first, key, value, output = ctx.saved_tensors
|
||||||
|
@ -188,17 +188,14 @@ class RwkvLinearAttention(torch.autograd.Function):
|
||||||
g_key,
|
g_key,
|
||||||
g_value,
|
g_value,
|
||||||
)
|
)
|
||||||
g_time_decay = torch.sum(g_time_decay, dim=0)
|
|
||||||
g_time_first = torch.sum(g_time_first, dim=0)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
g_time_decay.to(input_dtype),
|
g_time_decay.to(input_dtype),
|
||||||
g_time_first.to(input_dtype),
|
g_time_first.to(input_dtype),
|
||||||
g_key.to(input_dtype),
|
g_key.to(input_dtype),
|
||||||
g_value.to(input_dtype),
|
g_value.to(input_dtype),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue