now works for word-level LM

main
BlinkDL 4 years ago
parent 64fdb61056
commit 01d6972f4f

@ -21,7 +21,7 @@ model_type = 'RWKV' # 'RWKV' or 'RotaryMHA' or 'MHA-Plus'
datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
model_level = 'character' # 'character' or 'word' model_level = 'character' # 'character' or 'word'
ctx_size = 256 if 'character' else 128 ctx_size = 256 if model_level == 'character' else 128
nLayers = 5 nLayers = 5
nHead = 8 nHead = 8
nEmb = 512 nEmb = 512
@ -87,7 +87,10 @@ LENGTH_OF_EACH = 300
for run in range(NUM_OF_RUNS): for run in range(NUM_OF_RUNS):
context = "It was" context = "It was"
x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64) if model_level == 'word':
x = np.array([train_dataset.stoi[s] for s in context.split(' ')], dtype=np.int64)
else:
x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
real_len = len(x) real_len = len(x)
if real_len < MAX_LEN: if real_len < MAX_LEN:
@ -114,7 +117,11 @@ for run in range(NUM_OF_RUNS):
real_len += 1 real_len += 1
if i % 10 == 9 or i == LENGTH_OF_EACH-1: if i % 10 == 9 or i == LENGTH_OF_EACH-1:
completion = ''.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]]) if model_level == 'word':
completion = ' ' + ' '.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
completion = completion.replace('\n ', '\n')
else:
completion = ''.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
print(completion, end = '') print(completion, end = '')
print_begin = real_len print_begin = real_len
print() print()

Loading…
Cancel
Save