Fixed minor bug when running training on CUDA

Authored by w4nderlust on 2019-11-26 21:30:57 -08:00, committed by Julien Chaumond
parent 0b51fba20b
commit 7469d03b1c
1 changed file with 26 additions and 23 deletions


@@ -18,6 +18,7 @@ import torch.utils.data as data
 from nltk.tokenize.treebank import TreebankWordDetokenizer
 from torchtext import data as torchtext_data
 from torchtext import datasets
+
 from transformers import GPT2Tokenizer, GPT2LMHeadModel
 
 torch.manual_seed(0)
@@ -89,7 +90,7 @@ class Discriminator(torch.nn.Module):
         if self.cached_mode:
             avg_hidden = x.to(device)
         else:
-            avg_hidden = self.avg_representation(x)
+            avg_hidden = self.avg_representation(x.to(device))
 
         logits = self.classifier_head(avg_hidden)
         probs = F.log_softmax(logits, dim=-1)
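This hunk is the core of the fix: when the model's weights live on the GPU, passing a CPU tensor into avg_representation raises a device-mismatch RuntimeError, so the input is moved onto the model's device first. A minimal sketch of the failure mode and the fix (the Embedding stand-in and the device global are illustrative, not the script's actual classes):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the discriminator's representation step: weights live on `device`.
embed = torch.nn.Embedding(num_embeddings=50257, embedding_dim=16).to(device)

x = torch.tensor([[15496, 995]], dtype=torch.long)  # CPU tensor, as in the old code

# Old code: embed(x) fails on a CUDA machine ("Expected all tensors to be
# on the same device"). The fix moves the input over before the forward pass:
avg_hidden = embed(x.to(device)).mean(dim=1)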
@@ -203,7 +204,7 @@ def evaluate_performance(data_loader, discriminator):
 
 def predict(input_sentence, model, classes, cached=False):
     input_t = model.tokenizer.encode(input_sentence)
-    input_t = torch.tensor([input_t], dtype=torch.long)
+    input_t = torch.tensor([input_t], dtype=torch.long, device=device)
     if cached:
         input_t = model.avg_representation(input_t)
 
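Same class of bug in predict: the encoded input was materialized on the CPU by default and then handed to a model that may sit on the GPU. Passing device= at construction creates the tensor in the right place and saves a copy. A sketch, assuming a module-level device string as the script appears to use:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

token_ids = [15496, 995]  # e.g. the output of tokenizer.encode(input_sentence)

# Old: torch.tensor([token_ids], dtype=torch.long) always allocates on the CPU.
# New: the tensor is created directly on the target device.
input_t = torch.tensor([token_ids], dtype=torch.long, device=device)
assert input_t.device.type == torch.device(device).type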
@@ -428,7 +429,8 @@ def train_discriminator(
         with open(dataset_fp) as f:
             csv_reader = csv.reader(f, delimiter='\t')
             for row in csv_reader:
-                classes.add(row[0])
+                if row:
+                    classes.add(row[0])
 
         idx2class = sorted(classes)
         class2idx = {c: i for i, c in enumerate(idx2class)}
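The new `if row:` guard matters because csv.reader yields an empty list for every blank line in the input file, and row[0] on an empty list raises IndexError. A small repro of the behavior the guard protects against (the sample data is made up):

import csv
import io

data = "pos\tgreat movie\n\nneg\tterrible plot\n"  # note the blank middle line

classes = set()
for row in csv.reader(io.StringIO(data), delimiter="\t"):
    if row:  # blank lines come through as [], so row[0] would raise IndexError
        classes.add(row[0])

print(sorted(classes))  # ['neg', 'pos']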
@@ -444,30 +446,31 @@ def train_discriminator(
         with open(dataset_fp) as f:
             csv_reader = csv.reader(f, delimiter='\t')
             for i, row in enumerate(csv_reader):
-                label = row[0]
-                text = row[1]
+                if row:
+                    label = row[0]
+                    text = row[1]
 
-                try:
-                    seq = discriminator.tokenizer.encode(text)
+                    try:
+                        seq = discriminator.tokenizer.encode(text)
 
-                    if (len(seq) < max_length_seq):
-                        seq = torch.tensor(
-                            [50256] + seq,
-                            device=device,
-                            dtype=torch.long
-                        )
-                    else:
-                        print("Line {} is longer than maximum length {}".format(
-                            i, max_length_seq
-                        ))
-                        continue
+                        if (len(seq) < max_length_seq):
+                            seq = torch.tensor(
+                                [50256] + seq,
+                                device=device,
+                                dtype=torch.long
+                            )
+                        else:
+                            print("Line {} is longer than maximum length {}".format(
+                                i, max_length_seq
+                            ))
+                            continue
 
-                    x.append(seq)
-                    y.append(class2idx[label])
+                        x.append(seq)
+                        y.append(class2idx[label])
 
-                except:
-                    print("Error tokenizing line {}, skipping it".format(i))
-                    pass
+                    except:
+                        print("Error tokenizing line {}, skipping it".format(i))
+                        pass
 
         full_dataset = Dataset(x, y)
         train_size = int(0.9 * len(full_dataset))
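The diff cuts off after train_size; in PyTorch a 90/10 split like this is typically completed with torch.utils.data.random_split, along the lines of the sketch below (the TensorDataset stand-in replaces the script's own Dataset(x, y) wrapper):

import torch
import torch.utils.data as data

# Stand-in for the script's Dataset(x, y): 100 (sequence, label) pairs.
full_dataset = data.TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))

train_size = int(0.9 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = data.random_split(
    full_dataset, [train_size, test_size]
)

print(len(train_dataset), len(test_dataset))  # 90 10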