|
|
|
|
@ -21,7 +21,7 @@ model_type = 'RWKV' # 'RWKV' or 'RotaryMHA' or 'MHA-Plus'
|
|
|
|
|
datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
|
|
|
|
|
model_level = 'character' # 'character' or 'word'
|
|
|
|
|
|
|
|
|
|
ctx_size = 256 if 'character' else 128
|
|
|
|
|
ctx_size = 256 if model_level == 'character' else 128
|
|
|
|
|
nLayers = 5
|
|
|
|
|
nHead = 8
|
|
|
|
|
nEmb = 512
|
|
|
|
|
@ -87,6 +87,9 @@ LENGTH_OF_EACH = 300
|
|
|
|
|
for run in range(NUM_OF_RUNS):
|
|
|
|
|
context = "It was"
|
|
|
|
|
|
|
|
|
|
if model_level == 'word':
|
|
|
|
|
x = np.array([train_dataset.stoi[s] for s in context.split(' ')], dtype=np.int64)
|
|
|
|
|
else:
|
|
|
|
|
x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
|
|
|
|
|
|
|
|
|
|
real_len = len(x)
|
|
|
|
|
@ -114,6 +117,10 @@ for run in range(NUM_OF_RUNS):
|
|
|
|
|
real_len += 1
|
|
|
|
|
|
|
|
|
|
if i % 10 == 9 or i == LENGTH_OF_EACH-1:
|
|
|
|
|
if model_level == 'word':
|
|
|
|
|
completion = ' ' + ' '.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
|
|
|
|
|
completion = completion.replace('\n ', '\n')
|
|
|
|
|
else:
|
|
|
|
|
completion = ''.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]])
|
|
|
|
|
print(completion, end = '')
|
|
|
|
|
print_begin = real_len
|
|
|
|
|
|