⚠️ [CLAP] Fix dtype of logit scales in init (#25682)

[CLAP] Fix dtype of logit scales
This commit is contained in:
Sanchit Gandhi 2023-08-23 13:17:37 +01:00 committed by GitHub
parent 2cf87e2bbb
commit 77cb2ab792
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 3 deletions

View File

@ -18,7 +18,6 @@ import math
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
@ -1956,8 +1955,8 @@ class ClapModel(ClapPreTrainedModel):
text_config = config.text_config
audio_config = config.audio_config
self.logit_scale_a = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value)))
self.logit_scale_t = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value)))
self.logit_scale_a = nn.Parameter(torch.log(torch.tensor(config.logit_scale_init_value)))
self.logit_scale_t = nn.Parameter(torch.log(torch.tensor(config.logit_scale_init_value)))
self.projection_dim = config.projection_dim