Skip to content

Commit ee9e72f

Browse files
committed
Add fp16 support
1 parent 1a33ba2 commit ee9e72f

2 files changed

Lines changed: 14 additions & 5 deletions

File tree

slapo/verify.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ def __exit__(self, *exc):
115115
is_pipeline = True
116116
else:
117117
is_pipeline = False
118+
if "dtype" in self.kwargs:
119+
logger.info("Using %s data type", self.kwargs["dtype"], ranks=0)
118120
# 1. Build the original model with random weights
119121
named_params = self.original_sch.mod.named_parameters()
120122
is_initialized = named_params.__next__()[1].device != torch.device("meta")
@@ -126,6 +128,8 @@ def __exit__(self, *exc):
126128
)
127129
# make sure all the buffers are on the right device
128130
original_mod = original_mod.to(self.device)
131+
# with the correct data type
132+
original_mod = original_mod.to(self.kwargs.get("dtype", torch.float32))
129133
# 2. Get the example inputs and outputs
130134
# Broadcast the example inputs from rank 0 in each TP/PP group
131135
# to other ranks in the same group.
@@ -253,7 +257,9 @@ def init_weights(mod, path):
253257
new_mod, _ = build(new_sch, init_weights=init_weights)
254258
# 8. Run the new model
255259
# make sure all the buffers are on the right device
256-
new_mod.to(self.device)
260+
new_mod = new_mod.to(self.device)
261+
# with the correct data type
262+
new_mod = new_mod.to(self.kwargs.get("dtype", torch.float32))
257263
if self.eval_mode:
258264
new_mod.eval()
259265
# make sure the random seeds are the same, which may affect the output of dropout
@@ -292,6 +298,7 @@ def init_weights(mod, path):
292298
# HF model may output shape-0 tensors for loss
293299
if self.loss_fn is not None and new_output.shape != original_output.shape:
294300
new_output = new_output.view(original_output.shape)
301+
new_output = new_output.to(original_output.dtype)
295302
torch.testing.assert_close(original_output, new_output)
296303
logger.info("Passed verification!")
297304
if not is_copy_failed:

tests/test_ds_pipeline.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,12 @@ def loss_fn(outputs, labels):
324324
ds_config_dict = get_ds_config(
325325
batch_size=bs,
326326
micro_batch_size_per_gpu=micro_bs,
327-
fp16=False,
327+
fp16=True,
328328
)
329329
device = "cuda"
330330
input_ids = torch.ones(micro_bs, seq_len, dtype=torch.long, device=device)
331331
attention_mask = torch.ones(
332-
micro_bs, seq_len, dtype=torch.float32, requires_grad=False, device=device
332+
micro_bs, seq_len, dtype=torch.float16, requires_grad=False, device=device
333333
)
334334
position_ids = torch.ones(
335335
micro_bs, seq_len, dtype=torch.long, requires_grad=False, device=device
@@ -346,6 +346,7 @@ def loss_fn(outputs, labels):
346346
topology=topology,
347347
config=ds_config_dict,
348348
init_weights=model._init_weights,
349+
dtype=torch.float16,
349350
):
350351
sch.trace_until(
351352
"transformer", tracer="huggingface", concrete_args=concrete_args
@@ -414,12 +415,12 @@ def loss_fn(outputs, labels):
414415
ds_config_dict = get_ds_config(
415416
batch_size=bs,
416417
micro_batch_size_per_gpu=micro_bs,
417-
fp16=False,
418+
fp16=True,
418419
)
419420
device = "cuda"
420421
input_ids = torch.ones(micro_bs, seq_len, dtype=torch.long, device=device)
421422
attention_mask = torch.ones(
422-
micro_bs, seq_len, dtype=torch.float32, requires_grad=False, device=device
423+
micro_bs, seq_len, dtype=torch.float16, requires_grad=False, device=device
423424
)
424425
position_ids = torch.ones(
425426
micro_bs, seq_len, dtype=torch.long, requires_grad=False, device=device
@@ -436,6 +437,7 @@ def loss_fn(outputs, labels):
436437
topology=topology,
437438
config=ds_config_dict,
438439
init_weights=model._init_weights,
440+
dtype=torch.float16,
439441
):
440442
sch.trace_until(
441443
"transformer", tracer="huggingface", concrete_args=concrete_args

0 commit comments

Comments (0)