Fixed minor bug when running training on cuda
This commit is contained in:
parent
0b51fba20b
commit
7469d03b1c
|
@ -18,6 +18,7 @@ import torch.utils.data as data
|
||||||
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
||||||
from torchtext import data as torchtext_data
|
from torchtext import data as torchtext_data
|
||||||
from torchtext import datasets
|
from torchtext import datasets
|
||||||
|
|
||||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
@ -89,7 +90,7 @@ class Discriminator(torch.nn.Module):
|
||||||
if self.cached_mode:
|
if self.cached_mode:
|
||||||
avg_hidden = x.to(device)
|
avg_hidden = x.to(device)
|
||||||
else:
|
else:
|
||||||
avg_hidden = self.avg_representation(x)
|
avg_hidden = self.avg_representation(x.to(device))
|
||||||
|
|
||||||
logits = self.classifier_head(avg_hidden)
|
logits = self.classifier_head(avg_hidden)
|
||||||
probs = F.log_softmax(logits, dim=-1)
|
probs = F.log_softmax(logits, dim=-1)
|
||||||
|
@ -203,7 +204,7 @@ def evaluate_performance(data_loader, discriminator):
|
||||||
|
|
||||||
def predict(input_sentence, model, classes, cached=False):
|
def predict(input_sentence, model, classes, cached=False):
|
||||||
input_t = model.tokenizer.encode(input_sentence)
|
input_t = model.tokenizer.encode(input_sentence)
|
||||||
input_t = torch.tensor([input_t], dtype=torch.long)
|
input_t = torch.tensor([input_t], dtype=torch.long, device=device)
|
||||||
if cached:
|
if cached:
|
||||||
input_t = model.avg_representation(input_t)
|
input_t = model.avg_representation(input_t)
|
||||||
|
|
||||||
|
@ -428,7 +429,8 @@ def train_discriminator(
|
||||||
with open(dataset_fp) as f:
|
with open(dataset_fp) as f:
|
||||||
csv_reader = csv.reader(f, delimiter='\t')
|
csv_reader = csv.reader(f, delimiter='\t')
|
||||||
for row in csv_reader:
|
for row in csv_reader:
|
||||||
classes.add(row[0])
|
if row:
|
||||||
|
classes.add(row[0])
|
||||||
|
|
||||||
idx2class = sorted(classes)
|
idx2class = sorted(classes)
|
||||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||||
|
@ -444,30 +446,31 @@ def train_discriminator(
|
||||||
with open(dataset_fp) as f:
|
with open(dataset_fp) as f:
|
||||||
csv_reader = csv.reader(f, delimiter='\t')
|
csv_reader = csv.reader(f, delimiter='\t')
|
||||||
for i, row in enumerate(csv_reader):
|
for i, row in enumerate(csv_reader):
|
||||||
label = row[0]
|
if row:
|
||||||
text = row[1]
|
label = row[0]
|
||||||
|
text = row[1]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
seq = discriminator.tokenizer.encode(text)
|
seq = discriminator.tokenizer.encode(text)
|
||||||
if (len(seq) < max_length_seq):
|
if (len(seq) < max_length_seq):
|
||||||
seq = torch.tensor(
|
seq = torch.tensor(
|
||||||
[50256] + seq,
|
[50256] + seq,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.long
|
dtype=torch.long
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print("Line {} is longer than maximum length {}".format(
|
print("Line {} is longer than maximum length {}".format(
|
||||||
i, max_length_seq
|
i, max_length_seq
|
||||||
))
|
))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
x.append(seq)
|
x.append(seq)
|
||||||
y.append(class2idx[label])
|
y.append(class2idx[label])
|
||||||
|
|
||||||
except:
|
except:
|
||||||
print("Error tokenizing line {}, skipping it".format(i))
|
print("Error tokenizing line {}, skipping it".format(i))
|
||||||
pass
|
pass
|
||||||
|
|
||||||
full_dataset = Dataset(x, y)
|
full_dataset = Dataset(x, y)
|
||||||
train_size = int(0.9 * len(full_dataset))
|
train_size = int(0.9 * len(full_dataset))
|
||||||
|
|
Loading…
Reference in New Issue