diff --git a/train.py b/train.py index 59fe14c..4cdaba5 100644 --- a/train.py +++ b/train.py @@ -21,7 +21,7 @@ model_type = 'RWKV' # 'RWKV' or 'RotaryMHA' or 'MHA-Plus' datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip model_level = 'character' # 'character' or 'word' -ctx_size = 256 if 'character' else 128 +ctx_size = 256 if model_level == 'character' else 128 nLayers = 5 nHead = 8 nEmb = 512 @@ -87,7 +87,10 @@ LENGTH_OF_EACH = 300 for run in range(NUM_OF_RUNS): context = "It was" - x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64) + if model_level == 'word': + x = np.array([train_dataset.stoi[s] for s in context.split(' ')], dtype=np.int64) + else: + x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64) real_len = len(x) if real_len < MAX_LEN: @@ -114,7 +117,11 @@ for run in range(NUM_OF_RUNS): real_len += 1 if i % 10 == 9 or i == LENGTH_OF_EACH-1: - completion = ''.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]]) + if model_level == 'word': + completion = ' ' + ' '.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]]) + completion = completion.replace('\n ', '\n') + else: + completion = ''.join([train_dataset.itos[int(i)] for i in x[print_begin:real_len]]) print(completion, end = '') print_begin = real_len print()