Model Card: gaochangkuan README.md (#4033)

* Create README.md

* Update README.md

* tweak

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Scottish_Fold007 2020-05-01 10:26:58 +08:00 committed by GitHub
parent 8829ace4aa
commit 6b410bedfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 66 additions and 0 deletions

View File

@ -0,0 +1,66 @@
## Generating Chinese poetry by topic.
```python
from transformers import *
tokenizer = BertTokenizer.from_pretrained("gaochangkuan/model_dir")
model = AutoModelWithLMHead.from_pretrained("gaochangkuan/model_dir")
prompt= '''<s>田园躬耕'''
length= 84
stop_token='</s>'
temperature = 1.2
repetition_penalty=1.3
k= 30
p= 0.95
device ='cuda'
seed=2020
no_cuda=False
prompt_text = prompt if prompt else input("Model prompt >>> ")
encoded_prompt = tokenizer.encode(
'<s>'+prompt_text+'<sep>',
add_special_tokens=False,
return_tensors="pt"
)
encoded_prompt = encoded_prompt.to(device)
output_sequences = model.generate(
input_ids=encoded_prompt,
max_length=length,
min_length=10,
do_sample=True,
early_stopping=True,
num_beams=10,
temperature=temperature,
top_k=k,
top_p=p,
repetition_penalty=repetition_penalty,
bad_words_ids=None,
bos_token_id=tokenizer.bos_token_id,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
length_penalty=1.2,
no_repeat_ngram_size=2,
num_return_sequences=1,
attention_mask=None,
decoder_start_token_id=tokenizer.bos_token_id,)
generated_sequence = output_sequences[0].tolist()
text = tokenizer.decode(generated_sequence)
text = text[: text.find(stop_token) if stop_token else None]
print(''.join(text).replace(' ','').replace('<pad>','').replace('<s>',''))
```