-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathcommon_interactive.py
48 lines (40 loc) · 1.34 KB
/
common_interactive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from cmdline import args
def strip_instructions(prompt):
    """Return *prompt* starting at its first ``[/INST]`` marker, marker included.

    If the marker does not occur, *prompt* is returned unchanged. Callers use
    the result as a common anchor when lining up a prompt with its completions.
    """
    cut = prompt.find('[/INST]')
    return prompt if cut < 0 else prompt[cut:]
def diffprompt_default(prompt, results):
    """Return the part of each result that extends past *prompt*.

    Both the prompt and every result are first cut at their ``[/INST]`` marker
    via ``strip_instructions``. The length of the cut prompt is then chopped
    off the front of each cut result.

    Args:
        prompt: the text the results were generated from.
        results: iterable of generated texts (presumably prompt + continuation).

    Returns:
        A list with one trailing-text string per result, in input order.
    """
    offset = len(strip_instructions(prompt))
    return [anchored[offset:] for anchored in map(strip_instructions, results)]
def find_assistant(prompt, initial_prompt=""):
    """Extract the last complete turn from Llama-3-style chat text.

    Returns the text between the final two ``<|eot_id|>`` tags, with any
    ``user``/``assistant`` header markers removed from it. If the text holds
    fewer than two such tags, falls back to *prompt* with its first
    ``len(initial_prompt)`` characters dropped. The prefix itself is not
    checked.

    NOTE(review): if the text does not end with ``<|eot_id|>`` (for example a
    generation cut off before the model emitted one), the span between the last
    two tags is an earlier turn rather than the assistant's reply. Confirm
    whether callers can hit that case.
    """
    eot = "<|eot_id|>"
    last = prompt.rfind(eot)
    # Search strictly before the last tag. The guard avoids rfind treating -1
    # as a slice end.
    prev = prompt.rfind(eot, 0, last) if last != -1 else -1
    if prev == -1:
        return prompt[len(initial_prompt):]

    turn = prompt[prev + len(eot):last]
    for header in ("<|start_header_id|>user<|end_header_id|>",
                   "<|start_header_id|>assistant<|end_header_id|>"):
        turn = turn.replace(header, "")
    return turn
def diffprompt_llama3(prompt, results):
    """Return the extracted last turn of each result, via ``find_assistant``.

    *prompt* is handed to ``find_assistant`` as ``initial_prompt``. It is only
    used there as the fallback prefix length when a result lacks two
    ``<|eot_id|>`` tags.
    """
    return [find_assistant(completion, initial_prompt=prompt)
            for completion in results]
def choose_diffprompt(model_name):
    """Pick the completion-diffing function that matches *model_name*.

    Names containing ``llama3``, ``llama-3-`` or ``llama-3.`` (case-insensitive)
    get ``diffprompt_llama3``. Every other name gets ``diffprompt_default``.
    """
    lowered = model_name.lower()
    llama3_markers = ("llama3", "llama-3-", "llama-3.")
    if any(marker in lowered for marker in llama3_markers):
        return diffprompt_llama3
    return diffprompt_default
# Resolve the diffing strategy once, at import time, from the configured base
# model name supplied by the project's command-line arguments.
diffprompt = choose_diffprompt(model_name=args.base_model_name)
def ask_keep(prompt, texts):
    """Print each candidate with an index and ask the user which to keep.

    Each text in *texts* is first reduced by the module-level ``diffprompt``
    (chosen from the base model name) to its portion beyond *prompt*. It is
    then printed after its 0-based index.

    Args:
        prompt: the prompt the candidate texts were generated from.
        texts: the candidate texts to choose between.

    Returns:
        The reply as an ``int`` if it parses as one. Otherwise the stripped
        reply string, which the input prompt describes as a comment. The index
        is not range-checked here.
    """
    for i, t in enumerate(diffprompt(prompt, texts)):
        print(i, t)
    inp = input("Keep which? [0...] or comment: ").strip()
    try:
        return int(inp)
    except ValueError:
        # Anything that is not an integer is returned verbatim as a comment.
        return inp