@@ -324,12 +324,12 @@ def loss_fn(outputs, labels):
     ds_config_dict = get_ds_config(
         batch_size=bs,
         micro_batch_size_per_gpu=micro_bs,
-        fp16=False,
+        fp16=True,
     )
     device = "cuda"
     input_ids = torch.ones(micro_bs, seq_len, dtype=torch.long, device=device)
     attention_mask = torch.ones(
-        micro_bs, seq_len, dtype=torch.float32, requires_grad=False, device=device
+        micro_bs, seq_len, dtype=torch.float16, requires_grad=False, device=device
     )
     position_ids = torch.ones(
         micro_bs, seq_len, dtype=torch.long, requires_grad=False, device=device
@@ -346,6 +346,7 @@ def loss_fn(outputs, labels):
         topology=topology,
         config=ds_config_dict,
         init_weights=model._init_weights,
+        dtype=torch.float16,
     ):
         sch.trace_until(
             "transformer", tracer="huggingface", concrete_args=concrete_args
@@ -414,12 +415,12 @@ def loss_fn(outputs, labels):
     ds_config_dict = get_ds_config(
         batch_size=bs,
         micro_batch_size_per_gpu=micro_bs,
-        fp16=False,
+        fp16=True,
     )
     device = "cuda"
     input_ids = torch.ones(micro_bs, seq_len, dtype=torch.long, device=device)
     attention_mask = torch.ones(
-        micro_bs, seq_len, dtype=torch.float32, requires_grad=False, device=device
+        micro_bs, seq_len, dtype=torch.float16, requires_grad=False, device=device
     )
     position_ids = torch.ones(
         micro_bs, seq_len, dtype=torch.long, requires_grad=False, device=device
@@ -436,6 +437,7 @@ def loss_fn(outputs, labels):
         topology=topology,
         config=ds_config_dict,
         init_weights=model._init_weights,
+        dtype=torch.float16,
     ):
         sch.trace_until(
             "transformer", tracer="huggingface", concrete_args=concrete_args