⚠️ [CLAP] Fix dtype of logit scales in init (#25682)
[CLAP] Fix dtype of logit scales
This commit is contained in:
parent
2cf87e2bbb
commit
77cb2ab792
|
@ -18,7 +18,6 @@ import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, List, Optional, Tuple, Union
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
@ -1956,8 +1955,8 @@ class ClapModel(ClapPreTrainedModel):
|
||||||
text_config = config.text_config
|
text_config = config.text_config
|
||||||
audio_config = config.audio_config
|
audio_config = config.audio_config
|
||||||
|
|
||||||
self.logit_scale_a = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value)))
|
self.logit_scale_a = nn.Parameter(torch.log(torch.tensor(config.logit_scale_init_value)))
|
||||||
self.logit_scale_t = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value)))
|
self.logit_scale_t = nn.Parameter(torch.log(torch.tensor(config.logit_scale_init_value)))
|
||||||
|
|
||||||
self.projection_dim = config.projection_dim
|
self.projection_dim = config.projection_dim
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue