|
|
|
|
@ -83,10 +83,10 @@ def main(
|
|
|
|
|
ckpt_dir: str,
|
|
|
|
|
tokenizer_path: str,
|
|
|
|
|
temperature: float = 0.8,
|
|
|
|
|
top_p: float = 0.0, # use 0.95 or so for top_p sampler, and 0.0 for top_k sampler
|
|
|
|
|
top_p: float = 0.95, # use 0.95 or so for top_p sampler, and 0.0 for top_k sampler
|
|
|
|
|
top_k: int = 40,
|
|
|
|
|
repetition_penalty: float = (1.0 / 0.85), # 1.0 to disable repetition_penalty
|
|
|
|
|
sampler: str = 'top_k', # top_k or top_p
|
|
|
|
|
sampler: str = 'top_p', # top_p or top_k
|
|
|
|
|
max_seq_len: int = 2048,
|
|
|
|
|
max_batch_size: int = 1,
|
|
|
|
|
):
|
|
|
|
|
|