⚠️ [CLAP] Fix dtype of logit scales in init (#25682)

[CLAP] Fix dtype of logit scales
This commit is contained in:
Sanchit Gandhi 2023-08-23 13:17:37 +01:00 committed by GitHub
parent 2cf87e2bbb
commit 77cb2ab792
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 2 additions and 3 deletions

View File

@@ -18,7 +18,6 @@ import math
 from dataclasses import dataclass
 from typing import Any, List, Optional, Tuple, Union
-import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -1956,8 +1955,8 @@ class ClapModel(ClapPreTrainedModel):
         text_config = config.text_config
         audio_config = config.audio_config
-        self.logit_scale_a = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value)))
-        self.logit_scale_t = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value)))
+        self.logit_scale_a = nn.Parameter(torch.log(torch.tensor(config.logit_scale_init_value)))
+        self.logit_scale_t = nn.Parameter(torch.log(torch.tensor(config.logit_scale_init_value)))
         self.projection_dim = config.projection_dim