10 changes: 6 additions & 4 deletions msamp/megatron/layers.py
@@ -60,8 +60,9 @@ def forward(ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_a
         weight_fp8.requires_grad = weight.requires_grad

         # save tensors
-        ctx.input_fp8 = input_fp8
-        ctx.weight_fp8 = weight_fp8
+        ctx.input_fp8_sf = input_fp8.meta
+        ctx.weight_fp8_sf = weight_fp8.meta
+        ctx.save_for_backward(input_fp8.value.view(dtype=torch.float16), weight_fp8.value.view(dtype=torch.float16))
         ctx.weight = weight

         dim_size = list(input.size())
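The key mechanism in the forward change above is `Tensor.view(dtype=...)`: it reinterprets the existing uint8 storage as float16 (two bytes per element) without copying, so the FP8 payload can be routed through `ctx.save_for_backward` instead of being kept as a plain attribute on `ctx`. A minimal, self-contained sketch of this pattern, using a hypothetical autograd function rather than MS-AMP's real classes:

```python
import torch

class _SaveFP8PayloadDemo(torch.autograd.Function):
    """Hypothetical demo: route a uint8 buffer through save_for_backward."""

    @staticmethod
    def forward(ctx, x, fp8_payload):
        # fp8_payload is a uint8 buffer whose last dimension is even, so its
        # storage can be reinterpreted as 2-byte float16 elements without a copy.
        ctx.save_for_backward(fp8_payload.view(dtype=torch.float16))
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        (payload_fp16,) = ctx.saved_tensors
        fp8_payload = payload_fp16.view(dtype=torch.uint8)  # exact bitwise round-trip
        # ... use fp8_payload to compute gradients ...
        return grad_out, None
```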
@@ -100,8 +101,9 @@ def backward(ctx, grad_output):
         Returns:
             A tuple of gradients of the arguments.
         """
-        input_fp8 = ctx.input_fp8
-        weight_fp8 = ctx.weight_fp8
+        input_fp8_fp16, weight_fp8_fp16 = ctx.saved_tensors
+        input_fp8 = ScalingTensor(input_fp8_fp16.view(dtype=torch.uint8), meta=ctx.input_fp8_sf)
+        weight_fp8 = ScalingTensor(weight_fp8_fp16.view(dtype=torch.uint8), meta=ctx.weight_fp8_sf)
         input = input_fp8.value
         output_qtype = ctx.output_qtype
         metas = ctx.metas
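The backward side reverses the trick: the float16 views come out of `ctx.saved_tensors`, are viewed as uint8 again, and are wrapped back into `ScalingTensor`s using the metas stashed during forward. The round-trip is lossless and copy-free, which a couple of lines of plain PyTorch can confirm (standalone check, not part of the patch):

```python
import torch

raw = torch.randint(0, 256, (4, 16), dtype=torch.uint8)   # stand-in FP8 payload
saved = raw.view(dtype=torch.float16)         # what forward() passes to save_for_backward
restored = saved.view(dtype=torch.uint8)      # what backward() recovers
assert torch.equal(raw, restored)             # no bits change across the two views
assert restored.data_ptr() == raw.data_ptr()  # and no copy is made
```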
17 changes: 11 additions & 6 deletions msamp/nn/functional.py
@@ -45,9 +45,10 @@ def forward(ctx, input, weight, metas, dtype_holder):
         input_fp8 = input.cast(Dtypes.kfloat8_e4m3, meta=input_meta)
         weight_fp8 = weight.cast(Dtypes.kfloat8_e4m3)

-        ctx.input_fp8 = input_fp8
-        ctx.input_fp8.requires_grad = input.requires_grad
-        ctx.weight_fp8 = weight_fp8
+        ctx.input_fp8_sf = input_fp8.meta
+        ctx.weight_fp8_sf = weight_fp8.meta
+        ctx.save_for_backward(input_fp8.value.view(dtype=torch.float16), weight_fp8.value.view(dtype=torch.float16))
+        ctx.input_fp8_requires_grad = input.requires_grad
         ctx.weight = weight

         output_dtype = dtype_holder.dtype
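Beyond the payload itself, the `requires_grad` flag is now kept as a plain bool (`ctx.input_fp8_requires_grad`), since the ScalingTensor no longer lives on `ctx` to be queried in backward. One practical difference between the two storage styles is that tensors passed to `ctx.save_for_backward` are visible to PyTorch's saved-tensor hooks (used by offloading and activation-memory tooling), while plain `ctx` attributes are not. An illustrative snippet of that general PyTorch behavior (not part of this PR):

```python
import torch

seen = []

def pack(t):
    seen.append((t.dtype, tuple(t.shape)))  # record every tensor saved for backward
    return t

def unpack(t):
    return t

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    x = torch.randn(2, 8, requires_grad=True)
    y = (x * x).sum()   # mul saves its operands via save_for_backward

y.backward()
print(seen)             # the saved operands show up here; plain ctx attributes would not
```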
@@ -80,14 +81,18 @@ def backward(ctx, output_grad):
         wgrad_meta = metas['wgrad']
         ograd_fp8, ograd_fp8_t = output_grad.fused_cast_transpose(Dtypes.kfloat8_e5m2, meta=ograd_meta)

-        if ctx.input_fp8.requires_grad:
-            weight_fp8_t = ctx.weight_fp8.fp8_transpose()
+        input_fp8_fp16, weight_fp8_fp16 = ctx.saved_tensors
+        input_fp8 = ScalingTensor(input_fp8_fp16.view(dtype=torch.uint8), meta=ctx.input_fp8_sf)
+        weight_fp8 = ScalingTensor(weight_fp8_fp16.view(dtype=torch.uint8), meta=ctx.weight_fp8_sf)
+
+        if ctx.input_fp8_requires_grad:
+            weight_fp8_t = weight_fp8.fp8_transpose()
             input_grad = Gemm.fp8_gemm(weight_fp8_t, ograd_fp8, ctx.output_qtype, use_split_accumulator=True)
         else:
             input_grad = None

         if ctx.weight.requires_grad:
-            input_fp8_t = ctx.input_fp8.fp8_transpose()
+            input_fp8_t = input_fp8.fp8_transpose()
             wgrad_qtype = ctx.output_qtype
             # compute weight gradient
             if ctx.weight.grad is None:
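Since the same reconstruction now appears in both backward passes, a small helper could factor it out. A hedged sketch (hypothetical helper name, assuming the usual `msamp.common.tensor` import path for `ScalingTensor`):

```python
import torch
from msamp.common.tensor import ScalingTensor  # assumed import path

def _restore_fp8(saved_fp16: torch.Tensor, meta) -> ScalingTensor:
    """Rebuild a ScalingTensor from its float16-viewed uint8 payload and scaling meta."""
    return ScalingTensor(saved_fp16.view(dtype=torch.uint8), meta=meta)

# Usage inside backward (mirrors the diff above):
#   input_fp8_fp16, weight_fp8_fp16 = ctx.saved_tensors
#   input_fp8 = _restore_fp8(input_fp8_fp16, ctx.input_fp8_sf)
#   weight_fp8 = _restore_fp8(weight_fp8_fp16, ctx.weight_fp8_sf)
```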