now works for word-level LM

main
BlinkDL 4 years ago
parent 64fdb61056
commit 01d6972f4f

@ -21,7 +21,7 @@ model_type = 'RWKV' # 'RWKV' or 'RotaryMHA' or 'MHA-Plus'
datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
model_level = 'character' # 'character' or 'word'
ctx_size = 256 if 'character' else 128
ctx_size = 256 if model_level == 'character' else 128
nLayers = 5
nHead = 8
nEmb = 512
@ -87,6 +87,9 @@ LENGTH_OF_EACH = 300
for run in range(NUM_OF_RUNS):
context = "It was"
if model_level == 'word':
x = np.array([train_dataset.stoi[s] for s in context.split(' ')], dtype=np.int64)
else:
x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
real_len = len(x)
@ -114,6 +117,10 @@ for run in range(NUM_OF_RUNS):
real_len += 1
if i % 10 == 9 or i == LENGTH_OF_EACH-1:
if model_level == 'word':
completion = ' ' + ' '.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
completion = completion.replace('\n ', '\n')
else:
completion = ''.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
print(completion, end = '')
print_begin = real_len

Loading…
Cancel
Save