minor changes

main
BlinkDL 4 years ago
parent 88297e7949
commit 6266f481da

@ -4,7 +4,6 @@
import math import math
import logging import logging
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F

@ -6,8 +6,6 @@ import os, sys, time, math, random, json, datetime
import logging import logging
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset from torch.utils.data import Dataset
from src.trainer import Trainer, TrainerConfig from src.trainer import Trainer, TrainerConfig
from src.model import GPT, GPTConfig from src.model import GPT, GPTConfig
@ -42,6 +40,8 @@ nepoch = 50 # just a quick test. the 'epoch' here is very short
nbatchsz = 64 nbatchsz = 64
epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
########################################################################################################
# Load data
######################################################################################################## ########################################################################################################
print('loading data... ' + datafile) print('loading data... ' + datafile)
@ -79,6 +79,8 @@ class Dataset(Dataset):
train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_size) train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_size)
########################################################################################################
# Train model
######################################################################################################## ########################################################################################################
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_type=model_type, model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_type=model_type,
@ -94,11 +96,12 @@ trainer.train()
torch.save(model, 'trained-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth') torch.save(model, 'trained-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
########################################################################################################
# Run model to generate text
######################################################################################################## ########################################################################################################
from src.utils import sample_logits from src.utils import sample_logits
MAX_LEN = ctx_size
NUM_OF_RUNS = 5 NUM_OF_RUNS = 5
LENGTH_OF_EACH = 300 LENGTH_OF_EACH = 300
@ -111,8 +114,8 @@ for run in range(NUM_OF_RUNS):
x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64) x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
real_len = len(x) real_len = len(x)
if real_len < MAX_LEN: if real_len < ctx_size:
x = np.pad(x, (0, MAX_LEN - real_len)) x = np.pad(x, (0, ctx_size - real_len))
print_begin = 0 print_begin = 0
for i in range(LENGTH_OF_EACH): for i in range(LENGTH_OF_EACH):
@ -122,13 +125,13 @@ for run in range(NUM_OF_RUNS):
print_begin = real_len print_begin = real_len
with torch.no_grad(): with torch.no_grad():
xxx = torch.tensor(x[-MAX_LEN:], dtype=torch.long)[None,...].to("cuda:0") xxx = torch.tensor(x[-ctx_size:], dtype=torch.long)[None,...].to("cuda:0")
out, _ = model(xxx) out, _ = model(xxx)
pos = -1 if real_len >= MAX_LEN else real_len - 1 pos = -1 if real_len >= ctx_size else real_len - 1
char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02) char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02) # our special sampling method
if real_len < MAX_LEN: if real_len < ctx_size:
x[real_len] = char x[real_len] = char
else: else:
x = np.append(x, char) x = np.append(x, char)

Loading…
Cancel
Save