Skip to content

Commit 5d40afa

Browse files
Merge pull request #121 from OpenBioLink/input_list_and_str
Let inputs be strings, not only lists. E.g. Collection("worldtree") is also possible, not only Collection(["worldtree"]).
2 parents eeb957a + 2739675 commit 5d40afa

4 files changed

Lines changed: 46 additions & 8 deletions

File tree

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ collection.evaluate()
175175
## Versioning
176176
All updates/changes to datasets are explicitly mentioned in bold.
177177

178+
0.0.5 (2023-03-10) - Function to select which generated CoTs to keep after loading: collection.select_generated_cots(author="thoughtsource")
179+
178180
0.0.4 (2023-03-08) - Evaluation function improved. Function to load ThoughtSource100 collection: Collection.load_thoughtsource_100()
179181

180182
0.0.3 (2023-02-24) - ThoughtSource_100 collection released with reasoning chains from GPT-text-davinci-003, flan-t5-xxl, and cohere's command-xl

libs/cot/cot/config.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def __post_init__(self):
7575
# replace all keys (or non given keys) in config with the corresponding values
7676

7777
# Inserts None at index 0 of instruction_keys to query without an explicit instruction
78-
# TODO rethink this, maybe add option to disable this
7978
if self.instruction_keys == "all":
8079
self.instruction_keys = [None] + list(FRAGMENTS["instructions"].keys())
8180
elif not self.instruction_keys:
@@ -91,6 +90,25 @@ def __post_init__(self):
9190
elif not self.answer_extraction_keys:
9291
self.answer_extraction_keys = [None]
9392

93+
# turn strings into lists for all trigger keys
94+
if isinstance(self.instruction_keys, str):
95+
self.instruction_keys = [self.instruction_keys]
96+
if isinstance(self.cot_trigger_keys, str):
97+
self.cot_trigger_keys = [self.cot_trigger_keys]
98+
if isinstance(self.answer_extraction_keys, str):
99+
self.answer_extraction_keys = [self.answer_extraction_keys]
100+
101+
# check if all keys are valid
102+
for key in self.instruction_keys:
103+
if key is not None and key not in FRAGMENTS["instructions"]:
104+
raise ValueError(f"Given instruction key '{key}' is not in fragments.json.")
105+
for key in self.cot_trigger_keys:
106+
if key is not None and key not in FRAGMENTS["cot_triggers"]:
107+
raise ValueError(f"Given cot_trigger key '{key}' is not in fragments.json.")
108+
for key in self.answer_extraction_keys:
109+
if key is not None and key not in FRAGMENTS["answer_extractions"]:
110+
raise ValueError(f"Given answer_extraction key '{key}' is not in fragments.json.")
111+
94112
# check if the templates contain only allowed keys
95113
import re
96114

@@ -115,15 +133,12 @@ def __post_init__(self):
115133
assert self.idx_range[0] < self.idx_range[1], "idx_range must be a tuple of ints with idx_range[0] < idx_range[1]"
116134

117135
if self.instruction_keys != "all":
118-
assert isinstance(self.instruction_keys, list), "instruction_keys must be a list"
119136
assert all(isinstance(key, (str, type(None))) for key in self.instruction_keys), "instruction_keys must be a list of strings"
120137

121138
if self.cot_trigger_keys != "all":
122-
assert isinstance(self.cot_trigger_keys, list), "cot_trigger_keys must be a list"
123139
assert all(isinstance(key, (str, type(None))) for key in self.cot_trigger_keys), "cot_trigger_keys must be a list of strings"
124140

125141
if self.answer_extraction_keys != "all":
126-
assert isinstance(self.answer_extraction_keys, list), "answer_extraction_keys must be a list"
127142
assert all(
128143
isinstance(key, (str, type(None))) for key in self.answer_extraction_keys
129144
), "answer_extraction_keys must be a list of strings"

libs/cot/cot/dataloader.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,19 @@ def __init__(self, names=None, verbose=True, generate_mode=None, source=False, l
6161
"load_pregenerated_cots only works if datasets are loaded in ThoughSource view. \
6262
Param source needs to be False for pregenerated CoTs to be loaded."
6363
)
64+
65+
# if dataset name is a string, convert to list
66+
if isinstance(names, str) and names != "all":
67+
names = [names]
68+
# test if dataset name is valid
69+
if names is not None and names != "all":
70+
for name in names:
71+
available_datasets = Collection._all_available_datasets()
72+
if name not in available_datasets:
73+
raise ValueError(
74+
f"""Dataset '{name}' not found. Please check spelling.
75+
Available datasets: {available_datasets}"""
76+
)
6477

6578
if generate_mode in ["redownload", "recache"]:
6679
# see https://huggingface.co/docs/datasets/v2.1.0/en/package_reference/builder_classes#datasets.DownloadMode
@@ -90,7 +103,7 @@ def __init__(self, names=None, verbose=True, generate_mode=None, source=False, l
90103
self.load_datasets(names)
91104

92105
# unfortunately all generated cots have to be loaded when loading datasets in ThoughtSource view
93-
# here: all or None, or select specific generated cots with select_generated_cots
106+
# here: all or None, selection of specific generated cots can be done later with select_generated_cots
94107
if not load_pregenerated_cots and not source:
95108
self.delete_all_generated_cots()
96109

@@ -151,6 +164,10 @@ def _find_datasets(names=None):
151164
else:
152165
dataloader_scripts = [(name, path_to_biodatasets / name / (name + ".py")) for name in names]
153166
return dataloader_scripts
167+
168+
@staticmethod
169+
def _all_available_datasets():
170+
return [name for name, _ in Collection._find_datasets()]
154171

155172
def _get_metadata(self):
156173
for name, script_path in Collection._find_datasets():

libs/cot/cot/generate.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,14 @@ def _generate_and_extract(
208208
item["generated_cot"].append(generated_cot)
209209

210210
except Exception as ex:
211+
# if last try, raise error
212+
if i == number_of_tries - 1:
213+
raise ex
214+
215+
# if not last try, add additional time to api_time_interval and try again
211216
additional_api_time += 10
212-
print("API-Error in item " + str(idx) + ": " + str(ex))
217+
print("(API-)Error in item " + str(idx) + ": " + str(ex))
213218
print("Retrying with additional time of " + str(additional_api_time) + " seconds.")
214-
# if you want the error to be raised, uncomment the following line:
215-
# raise ex
216219
pass
217220

218221
else:
@@ -406,6 +409,7 @@ def __getitem__(self, key):
406409

407410
def query_model(input, api_service, engine, temperature, max_tokens, api_time_interval):
408411
if api_service == "mock_api":
412+
# time.sleep(api_time_interval)
409413
return " Test mock chain of thought."
410414
# return ("This is a " + 20 * "long " + "Mock CoT.\n")*20
411415

0 commit comments

Comments
 (0)