From 6266f481da97a9c6ab9748630cd307c0d81d207b Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Wed, 11 Aug 2021 15:53:44 +0800
Subject: [PATCH] minor changes

---
 src/model.py |  1 -
 train.py     | 21 ++++++++++++---------
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/src/model.py b/src/model.py
index 494ce5a..d5f36a2 100644
--- a/src/model.py
+++ b/src/model.py
@@ -4,7 +4,6 @@
 
 import math
 import logging
-import numpy as np
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
diff --git a/train.py b/train.py
index 97ff024..190a18e 100644
--- a/train.py
+++ b/train.py
@@ -6,8 +6,6 @@ import os, sys, time, math, random, json, datetime
 import logging
 import numpy as np
 import torch
-import torch.nn as nn
-from torch.nn import functional as F
 from torch.utils.data import Dataset
 from src.trainer import Trainer, TrainerConfig
 from src.model import GPT, GPTConfig
@@ -42,6 +40,8 @@ nepoch = 50 # just a quick test. the 'epoch' here is very short
 nbatchsz = 64
 epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
 
+########################################################################################################
+# Load data
 ########################################################################################################
 
 print('loading data... ' + datafile)
@@ -79,6 +79,8 @@ class Dataset(Dataset):
 
 train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_size)
 
+########################################################################################################
+# Train model
 ########################################################################################################
 
 model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_size, model_type=model_type,
@@ -94,11 +96,12 @@ trainer.train()
 
 torch.save(model, 'trained-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
 
+########################################################################################################
+# Run model to generate text
 ########################################################################################################
 
 from src.utils import sample_logits
 
-MAX_LEN = ctx_size
 NUM_OF_RUNS = 5
 LENGTH_OF_EACH = 300
 
@@ -111,8 +114,8 @@ for run in range(NUM_OF_RUNS):
 
     x = np.array([train_dataset.stoi[s] for s in context], dtype=np.int64)
     real_len = len(x)
-    if real_len < MAX_LEN:
-        x = np.pad(x, (0, MAX_LEN - real_len))
+    if real_len < ctx_size:
+        x = np.pad(x, (0, ctx_size - real_len))
     print_begin = 0
 
     for i in range(LENGTH_OF_EACH):
@@ -122,13 +125,13 @@ for run in range(NUM_OF_RUNS):
             print_begin = real_len
 
         with torch.no_grad():
-            xxx = torch.tensor(x[-MAX_LEN:], dtype=torch.long)[None,...].to("cuda:0")
+            xxx = torch.tensor(x[-ctx_size:], dtype=torch.long)[None,...].to("cuda:0")
             out, _ = model(xxx)
-        pos = -1 if real_len >= MAX_LEN else real_len - 1
+        pos = -1 if real_len >= ctx_size else real_len - 1
 
-        char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02)
+        char = sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02) # our special sampling method
 
-        if real_len < MAX_LEN:
+        if real_len < ctx_size:
             x[real_len] = char
         else:
             x = np.append(x, char)
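
The generation loop above samples with sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02) from src/utils, which this patch leaves untouched, so its body is not shown here. Below is a minimal sketch of a sampler with that signature, assuming (from the parameter names alone, not from the repository's code) that min_p_pow and min_p_ratio define a probability floor relative to the most likely token; treat it as an illustration, not the actual implementation.

    import torch
    from torch.nn import functional as F

    def sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=0.02):
        # out: model output of shape [batch, ctx_size, vocab]; pos: position to sample from
        probs = F.softmax(out[0][pos], dim=-1)
        # hypothetical "min-p" floor: drop tokens far less likely than the top token
        limit = torch.max(probs) ** min_p_pow * min_p_ratio
        probs[probs < limit] = 0
        if temperature != 1.0:
            probs = probs ** (1.0 / temperature)
        probs = probs / torch.sum(probs)
        return torch.multinomial(probs, num_samples=1).item()

In train.py the return value is written into the numpy context buffer x, so returning a plain Python int keeps the sketch compatible with the surrounding loop.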