2929"""
3030
3131import argparse
32+ import gc
3233import re
3334import sys
3435from pathlib import Path
@@ -104,30 +105,44 @@ def validate_model(model_name: str,
104105 allowed : set [str ],
105106 forbidden : set [str ],
106107 verbose : bool ,
107- quantize : bool = False ) -> bool :
108- """Validate one HuggingFace model. Returns True if all ops pass."""
108+ quantize : bool = False ,
109+ auto_class : str | None = None ,
110+ config_overrides : dict | None = None ) -> str :
111+ """Validate one HuggingFace model.
112+
113+ Returns "pass", "fail" (op validation failed), or "skip" (could not
114+ load/trace — e.g. private model without HF_TOKEN).
115+ """
109116 label = f"{ model_name } (quantized)" if quantize else model_name
110117 print (f" { label } ..." , file = sys .stderr )
111- traced = load_and_trace_hf_model (model_name , quantize = quantize )
118+ traced = load_and_trace_hf_model (model_name , quantize = quantize ,
119+ auto_class = auto_class ,
120+ config_overrides = config_overrides )
112121 if traced is None :
113- print (f" FAILED (could not load/trace)" , file = sys .stderr )
114- return False
122+ print (f" SKIPPED (could not load/trace)" , file = sys .stderr )
123+ return "skip"
115124 ops = collect_inlined_ops (traced )
116- return check_ops (ops , allowed , forbidden , verbose )
125+ result = "pass" if check_ops (ops , allowed , forbidden , verbose ) else "fail"
126+ del traced
127+ gc .collect ()
128+ return result
117129
118130
119131def validate_pt_file (name : str ,
120132 pt_path : str ,
121133 allowed : set [str ],
122134 forbidden : set [str ],
123- verbose : bool ) -> bool :
124- """Validate a local TorchScript .pt file. Returns True if all ops pass."""
135+ verbose : bool ) -> str :
136+ """Validate a local TorchScript .pt file.
137+
138+ Returns "pass", "fail", or "skip".
139+ """
125140 print (f" { name } ({ pt_path } )..." , file = sys .stderr )
126141 ops = load_pt_and_collect_ops (pt_path )
127142 if ops is None :
128- print (f" FAILED (could not load)" , file = sys .stderr )
129- return False
130- return check_ops (ops , allowed , forbidden , verbose )
143+ print (f" SKIPPED (could not load)" , file = sys .stderr )
144+ return "skip"
145+ return "pass" if check_ops (ops , allowed , forbidden , verbose ) else "fail"
131146
132147
133148def main ():
@@ -151,7 +166,7 @@ def main():
151166 print (f"Parsed { len (allowed )} allowed ops and { len (forbidden )} "
152167 f"forbidden ops from { SUPPORTED_OPS_CC .name } " , file = sys .stderr )
153168
154- results : dict [str , bool ] = {}
169+ results : dict [str , str ] = {}
155170
156171 models = load_model_config (args .config )
157172
@@ -161,7 +176,9 @@ def main():
161176 for arch , spec in models .items ():
162177 results [arch ] = validate_model (
163178 spec ["model_id" ], allowed , forbidden , args .verbose ,
164- quantize = spec ["quantized" ])
179+ quantize = spec ["quantized" ],
180+ auto_class = spec .get ("auto_class" ),
181+ config_overrides = spec .get ("config_overrides" ))
165182
166183 if args .pt_dir and args .pt_dir .is_dir ():
167184 pt_files = sorted (args .pt_dir .glob ("*.pt" ))
@@ -175,26 +192,32 @@ def main():
175192
176193 print (file = sys .stderr )
177194 print ("=" * 60 , file = sys .stderr )
178- all_pass = all (results .values ())
179- for key , passed in results .items ():
180- status = "PASS" if passed else "FAIL"
195+ for key , status in results .items ():
196+ display = status .upper ()
181197 if key .startswith ("pt:" ):
182- print (f" { key } : { status } " , file = sys .stderr )
198+ print (f" { key } : { display } " , file = sys .stderr )
183199 else :
184200 spec = models [key ]
185201 label = spec ["model_id" ]
186202 if spec ["quantized" ]:
187203 label += " (quantized)"
188- print (f" { key } ({ label } ): { status } " , file = sys .stderr )
204+ print (f" { key } ({ label } ): { display } " , file = sys .stderr )
205+
206+ failed = [a for a , s in results .items () if s == "fail" ]
207+ skipped = [a for a , s in results .items () if s == "skip" ]
208+ passed = [a for a , s in results .items () if s == "pass" ]
189209
190210 print ("=" * 60 , file = sys .stderr )
191- if all_pass :
192- print ("All models PASS - no false positives." , file = sys .stderr )
193- else :
194- failed = [a for a , p in results .items () if not p ]
195- print (f"FAILED models: { ', ' .join (failed )} " , file = sys .stderr )
211+ print (f"{ len (passed )} passed, { len (failed )} failed, "
212+ f"{ len (skipped )} skipped" , file = sys .stderr )
213+
214+ if skipped :
215+ print (f"Skipped (could not load/trace — may need HF_TOKEN "
216+ f"for private models): { ', ' .join (skipped )} " , file = sys .stderr )
217+ if failed :
218+ print (f"FAILED (op validation): { ', ' .join (failed )} " , file = sys .stderr )
196219
197- sys .exit (0 if all_pass else 1 )
220+ sys .exit (0 if not failed else 1 )
198221
199222
200223if __name__ == "__main__" :
0 commit comments