better splitting of words

main
BlinkDL 4 years ago
parent 01d6972f4f
commit 55405c57d0

@ -26,24 +26,30 @@ nLayers = 5
nHead = 8 nHead = 8
nEmb = 512 nEmb = 512
nepoch = 50 nepoch = 50 # just a quick test. the 'epoch' here is very short
nbatchsz = 64 nbatchsz = 64
epoch_length_fixed = 10000 # make an epoch very short, so we can see the training progress epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
######################################################################################################## ########################################################################################################
print("loading data...", end="") print('loading data... ' + datafile)
class Dataset(Dataset): class Dataset(Dataset):
def __init__(self, data, model_level, ctx_size): def __init__(self, data, model_level, ctx_size):
print('building token list...')
if model_level == 'word': if model_level == 'word':
data = data.replace('\n', ' \n ').replace(' ', ' ').split(' ') import re
data = re.sub(r'(\n|\.|\,|\?|\!|\:|\;|\-|\—|\||\'|\"|\`|\(|\)|[0-9]|\[|\]|\{|\}|\=|\+|\*|\\|\/|\~|\&|\$|\#|\%)', r' \g<0> ', data)
data = re.sub(' +',' ',data)
print('splitting token...')
data = data.lower().split(' ')
unique = sorted(list(set(data))) unique = sorted(list(set(data)))
for u in unique:
print(u, end=' ')
data_size, vocab_size = len(data), len(unique) data_size, vocab_size = len(data), len(unique)
print('\n\ndata has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
self.stoi = { ch:i for i,ch in enumerate(unique) } self.stoi = { ch:i for i,ch in enumerate(unique) }
self.itos = { i:ch for i,ch in enumerate(unique) } self.itos = { i:ch for i,ch in enumerate(unique) }
print('data has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
self.ctx_size = ctx_size self.ctx_size = ctx_size
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.data = data self.data = data
@ -88,7 +94,7 @@ for run in range(NUM_OF_RUNS):
context = "It was" context = "It was"
if model_level == 'word': if model_level == 'word':
x = np.array([train_dataset.stoi[s] for s in context.split(' ')], dtype=np.int64) x = np.array([train_dataset.stoi[s] for s in context.strip().lower().split(' ')], dtype=np.int64)
else: else:
x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64) x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)

Loading…
Cancel
Save