import os
import shutil
import struct
from functools import lru_cache
from itertools import accumulate

import numpy as np
import torch


def print_rank_0(*message):
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(*message, flush=True)
    else:
        print(*message, flush=True)


def _warmup_mmap_file(path):
    # Warm-up is currently disabled; the original implementation streamed the
    # file in 100 MB chunks to pull its pages into the OS cache.
    pass
    # with open(path, "rb") as stream:
    #     while stream.read(100 * 1024 * 1024):
    #         pass


# Mapping from the dtype codes stored in the .idx header to NumPy dtypes.
dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float64,  # np.float (removed in NumPy >= 1.24) was an alias for float64
    7: np.double,
    8: np.uint16,
}


def code(dtype):
    # Inverse lookup: NumPy dtype -> on-disk dtype code.
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)


def index_file_path(prefix_path):
    return prefix_path + ".idx"


def data_file_path(prefix_path):
    return prefix_path + ".bin"


class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index(object):
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        def __init__(self, path, skip_warmup=False):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                # Little endian unsigned 64 Bit integer
                version = struct.unpack("<Q", stream.read(8))