|
|
|
@ -6,8 +6,6 @@ import os, sys, time, math, random, json, datetime
|
|
|
|
import logging
|
|
|
|
import logging
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
from torch.nn import functional as F
|
|
|
|
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
from src.trainer import Trainer, TrainerConfig
|
|
|
|
from src.trainer import Trainer, TrainerConfig
|
|
|
|
from src.model import GPT, GPTConfig
|
|
|
|
from src.model import GPT, GPTConfig
|
|
|
|
@ -42,6 +40,8 @@ nepoch = 50 # just a quick test. the 'epoch' here is very short
|
|
|
|
nbatchsz = 64
|
|
|
|
nbatchsz = 64
|
|
|
|
epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
|
|
|
|
epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
# Load data
|
|
|
|
########################################################################################################
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
|
|
|
|
print('loading data... ' + datafile)
|
|
|
|
print('loading data... ' + datafile)
|
|
|
|
@ -79,6 +79,8 @@ class Dataset(Dataset):
|
|
|
|
|
|
|
|
|
|
|
|
train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_size)
|
|
|
|
train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
# Train model
|
|
|
|
########################################################################################################
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
|
|
|
|
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_type=model_type,
|
|
|
|
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_type=model_type,
|
|
|
|
@ -94,11 +96,12 @@ trainer.train()
|
|
|
|
|
|
|
|
|
|
|
|
torch.save(model, 'trained-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
|
|
|
|
torch.save(model, 'trained-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
# Run model to generate text
|
|
|
|
########################################################################################################
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
|
|
|
|
from src.utils import sample_logits
|
|
|
|
from src.utils import sample_logits
|
|
|
|
|
|
|
|
|
|
|
|
MAX_LEN = ctx_size
|
|
|
|
|
|
|
|
NUM_OF_RUNS = 5
|
|
|
|
NUM_OF_RUNS = 5
|
|
|
|
LENGTH_OF_EACH = 300
|
|
|
|
LENGTH_OF_EACH = 300
|
|
|
|
|
|
|
|
|
|
|
|
@ -111,8 +114,8 @@ for run in range(NUM_OF_RUNS):
|
|
|
|
x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
|
|
|
|
x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
|
|
|
|
|
|
|
|
|
|
|
|
real_len = len(x)
|
|
|
|
real_len = len(x)
|
|
|
|
if real_len < MAX_LEN:
|
|
|
|
if real_len < ctx_size:
|
|
|
|
x = np.pad(x, (0, MAX_LEN - real_len))
|
|
|
|
x = np.pad(x, (0, ctx_size - real_len))
|
|
|
|
print_begin = 0
|
|
|
|
print_begin = 0
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(LENGTH_OF_EACH):
|
|
|
|
for i in range(LENGTH_OF_EACH):
|
|
|
|
@ -122,13 +125,13 @@ for run in range(NUM_OF_RUNS):
|
|
|
|
print_begin = real_len
|
|
|
|
print_begin = real_len
|
|
|
|
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
with torch.no_grad():
|
|
|
|
xxx = torch.tensor(x[-MAX_LEN:], dtype=torch.long)[None,...].to("cuda:0")
|
|
|
|
xxx = torch.tensor(x[-ctx_size:], dtype=torch.long)[None,...].to("cuda:0")
|
|
|
|
out, _ = model(xxx)
|
|
|
|
out, _ = model(xxx)
|
|
|
|
pos = -1 if real_len >= MAX_LEN else real_len - 1
|
|
|
|
pos = -1 if real_len >= ctx_size else real_len - 1
|
|
|
|
|
|
|
|
|
|
|
|
char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02)
|
|
|
|
char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02) # our special sampling method
|
|
|
|
|
|
|
|
|
|
|
|
if real_len < MAX_LEN:
|
|
|
|
if real_len < ctx_size:
|
|
|
|
x[real_len] = char
|
|
|
|
x[real_len] = char
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
x = np.append(x, char)
|
|
|
|
x = np.append(x, char)
|
|
|
|
|