⚠️ [CLAP] Fix dtype of logit scales in init (#25682)
[CLAP] Fix dtype of logit scales
This commit is contained in:
parent
2cf87e2bbb
commit
77cb2ab792
|
@ -18,7 +18,6 @@ import math
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
@ -1956,8 +1955,8 @@ class ClapModel(ClapPreTrainedModel):
|
|||
text_config = config.text_config
|
||||
audio_config = config.audio_config
|
||||
|
||||
self.logit_scale_a = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value)))
|
||||
self.logit_scale_t = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value)))
|
||||
self.logit_scale_a = nn.Parameter(torch.log(torch.tensor(config.logit_scale_init_value)))
|
||||
self.logit_scale_t = nn.Parameter(torch.log(torch.tensor(config.logit_scale_init_value)))
|
||||
|
||||
self.projection_dim = config.projection_dim
|
||||
|
||||
|
|
Loading…
Reference in New Issue