@@ -75,7 +75,6 @@ def __post_init__(self):
7575 # replace all keys (or non given keys) in config with the corresponding values
7676
7777 # Inserts None at index 0 of instruction_keys to query without an explicit instruction
78- # TODO rethink this, maybe add option to disable this
7978 if self .instruction_keys == "all" :
8079 self .instruction_keys = [None ] + list (FRAGMENTS ["instructions" ].keys ())
8180 elif not self .instruction_keys :
@@ -91,6 +90,25 @@ def __post_init__(self):
9190 elif not self .answer_extraction_keys :
9291 self .answer_extraction_keys = [None ]
9392
93+ # turn strings into lists for all trigger keys
94+ if isinstance (self .instruction_keys , str ):
95+ self .instruction_keys = [self .instruction_keys ]
96+ if isinstance (self .cot_trigger_keys , str ):
97+ self .cot_trigger_keys = [self .cot_trigger_keys ]
98+ if isinstance (self .answer_extraction_keys , str ):
99+ self .answer_extraction_keys = [self .answer_extraction_keys ]
100+
101+ # check if all keys are valid
102+ for key in self .instruction_keys :
103+ if key is not None and key not in FRAGMENTS ["instructions" ]:
104+ raise ValueError (f"Given instruction key '{ key } ' is not in fragments.json." )
105+ for key in self .cot_trigger_keys :
106+ if key is not None and key not in FRAGMENTS ["cot_triggers" ]:
107+ raise ValueError (f"Given cot_trigger key '{ key } ' is not in fragments.json." )
108+ for key in self .answer_extraction_keys :
109+ if key is not None and key not in FRAGMENTS ["answer_extractions" ]:
110+ raise ValueError (f"Given answer_extraction key '{ key } ' is not in fragments.json." )
111+
94112 # check if the templates contain only allowed keys
95113 import re
96114
@@ -115,15 +133,12 @@ def __post_init__(self):
115133 assert self .idx_range [0 ] < self .idx_range [1 ], "idx_range must be a tuple of ints with idx_range[0] < idx_range[1]"
116134
117135 if self .instruction_keys != "all" :
118- assert isinstance (self .instruction_keys , list ), "instruction_keys must be a list"
119136 assert all (isinstance (key , (str , type (None ))) for key in self .instruction_keys ), "instruction_keys must be a list of strings"
120137
121138 if self .cot_trigger_keys != "all" :
122- assert isinstance (self .cot_trigger_keys , list ), "cot_trigger_keys must be a list"
123139 assert all (isinstance (key , (str , type (None ))) for key in self .cot_trigger_keys ), "cot_trigger_keys must be a list of strings"
124140
125141 if self .answer_extraction_keys != "all" :
126- assert isinstance (self .answer_extraction_keys , list ), "answer_extraction_keys must be a list"
127142 assert all (
128143 isinstance (key , (str , type (None ))) for key in self .answer_extraction_keys
129144 ), "answer_extraction_keys must be a list of strings"
0 commit comments