diff --git a/train.py b/train.py
index 4cdaba5..429f741 100644
--- a/train.py
+++ b/train.py
@@ -26,24 +26,30 @@
 nLayers = 5
 nHead = 8
 nEmb = 512
-nepoch = 50
+nepoch = 50 # just a quick test. the 'epoch' here is very short
 nbatchsz = 64
-epoch_length_fixed = 10000 # make an epoch very short, so we can see the training progress
+epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
 
 ########################################################################################################
 
-print("loading data...", end="")
+print('loading data... ' + datafile)
 
 class Dataset(Dataset):
     def __init__(self, data, model_level, ctx_size):
+        print('building token list...')
         if model_level == 'word':
-            data = data.replace('\n', ' \n ').replace('  ', ' ').split(' ')
-
+            import re
+            data = re.sub(r'(\n|\.|\,|\?|\!|\:|\;|\-|\—|\||\'|\"|\`|\(|\)|[0-9]|\[|\]|\{|\}|\=|\+|\*|\\|\/|\~|\&|\$|\#|\%)', r' \g<0> ', data)
+            data = re.sub(' +', ' ', data)
+            print('splitting token...')
+            data = data.lower().split(' ')
         unique = sorted(list(set(data)))
+        for u in unique:
+            print(u, end=' ')
         data_size, vocab_size = len(data), len(unique)
+        print('\n\ndata has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
         self.stoi = { ch:i for i,ch in enumerate(unique) }
         self.itos = { i:ch for i,ch in enumerate(unique) }
-        print('data has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
         self.ctx_size = ctx_size
         self.vocab_size = vocab_size
         self.data = data
@@ -88,7 +94,7 @@
 for run in range(NUM_OF_RUNS):
     context = "It was"
     if model_level == 'word':
-        x = np.array([train_dataset.stoi[s] for s in context.split(' ')], dtype=np.int64)
+        x = np.array([train_dataset.stoi[s] for s in context.strip().lower().split(' ')], dtype=np.int64)
     else:
         x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
 
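
For reference, a minimal standalone sketch of the two behavioral changes in this patch: the regex-based word-level tokenizer and the matching normalization of the sampling prompt. The regex, the whitespace-collapse step, and the normalization calls are copied verbatim from the patch; the sample corpus and the tiny stoi round trip are illustrative only.

    # Standalone sketch of the word-level tokenization introduced above,
    # plus the matching prompt normalization from the second hunk.
    # The sample corpus is made up for illustration.
    import re
    import numpy as np

    corpus = "It was a dark and stormy night. The rain fell in torrents!"

    # Pad every punctuation mark / digit with spaces so each becomes its own token
    corpus = re.sub(r'(\n|\.|\,|\?|\!|\:|\;|\-|\—|\||\'|\"|\`|\(|\)|[0-9]|\[|\]|\{|\}|\=|\+|\*|\\|\/|\~|\&|\$|\#|\%)', r' \g<0> ', corpus)
    corpus = re.sub(' +', ' ', corpus)  # collapse runs of spaces
    tokens = corpus.lower().split(' ')  # lowercase, then split on single spaces
    # note: split(' ') yields a trailing '' token when the text ends in a
    # padded character, so '' ends up in the vocabulary as well

    unique = sorted(list(set(tokens)))
    stoi = {ch: i for i, ch in enumerate(unique)}

    # The prompt must go through the same normalization: without .lower(),
    # 'It' would raise a KeyError against the lowercased vocabulary, which is
    # what the context.strip().lower().split(' ') change guards against.
    context = "It was"
    x = np.array([stoi[s] for s in context.strip().lower().split(' ')], dtype=np.int64)
    print(x)  # token ids of 'it' and 'was' in the sample vocabulary

Compared to the old `data.replace('\n', ' \n ')` approach, splitting punctuation and digits off into their own tokens means strings like "night." and "night" no longer count as distinct words, which should shrink the vocabulary on typical prose.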