diff --git a/msamp/megatron/layers.py b/msamp/megatron/layers.py
index d28d0c2a..e8a7fa97 100644
--- a/msamp/megatron/layers.py
+++ b/msamp/megatron/layers.py
@@ -60,8 +60,9 @@ def forward(ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_a
         weight_fp8.requires_grad = weight.requires_grad

         # save tensors
-        ctx.input_fp8 = input_fp8
-        ctx.weight_fp8 = weight_fp8
+        ctx.input_fp8_sf = input_fp8.meta
+        ctx.weight_fp8_sf = weight_fp8.meta
+        ctx.save_for_backward(input_fp8.value.view(dtype=torch.float16), weight_fp8.value.view(dtype=torch.float16))
         ctx.weight = weight

         dim_size = list(input.size())
@@ -100,8 +101,9 @@ def backward(ctx, grad_output):
         Returns:
             A tuple of gradients of the arguments.
         """
-        input_fp8 = ctx.input_fp8
-        weight_fp8 = ctx.weight_fp8
+        input_fp8_fp16, weight_fp8_fp16 = ctx.saved_tensors
+        input_fp8 = ScalingTensor(input_fp8_fp16.view(dtype=torch.uint8), meta=ctx.input_fp8_sf)
+        weight_fp8 = ScalingTensor(weight_fp8_fp16.view(dtype=torch.uint8), meta=ctx.weight_fp8_sf)
         input = input_fp8.value
         output_qtype = ctx.output_qtype
         metas = ctx.metas
diff --git a/msamp/nn/functional.py b/msamp/nn/functional.py
index 23d023cb..6912ab9d 100644
--- a/msamp/nn/functional.py
+++ b/msamp/nn/functional.py
@@ -45,9 +45,10 @@ def forward(ctx, input, weight, metas, dtype_holder):
         input_fp8 = input.cast(Dtypes.kfloat8_e4m3, meta=input_meta)
         weight_fp8 = weight.cast(Dtypes.kfloat8_e4m3)

-        ctx.input_fp8 = input_fp8
-        ctx.input_fp8.requires_grad = input.requires_grad
-        ctx.weight_fp8 = weight_fp8
+        ctx.input_fp8_sf = input_fp8.meta
+        ctx.weight_fp8_sf = weight_fp8.meta
+        ctx.save_for_backward(input_fp8.value.view(dtype=torch.float16), weight_fp8.value.view(dtype=torch.float16))
+        ctx.input_fp8_requires_grad = input.requires_grad
         ctx.weight = weight

         output_dtype = dtype_holder.dtype
@@ -80,14 +81,18 @@ def backward(ctx, output_grad):
         wgrad_meta = metas['wgrad']
         ograd_fp8, ograd_fp8_t = output_grad.fused_cast_transpose(Dtypes.kfloat8_e5m2, meta=ograd_meta)

-        if ctx.input_fp8.requires_grad:
-            weight_fp8_t = ctx.weight_fp8.fp8_transpose()
+        input_fp8_fp16, weight_fp8_fp16 = ctx.saved_tensors
+        input_fp8 = ScalingTensor(input_fp8_fp16.view(dtype=torch.uint8), meta=ctx.input_fp8_sf)
+        weight_fp8 = ScalingTensor(weight_fp8_fp16.view(dtype=torch.uint8), meta=ctx.weight_fp8_sf)
+
+        if ctx.input_fp8_requires_grad:
+            weight_fp8_t = weight_fp8.fp8_transpose()
             input_grad = Gemm.fp8_gemm(weight_fp8_t, ograd_fp8, ctx.output_qtype, use_split_accumulator=True)
         else:
             input_grad = None

         if ctx.weight.requires_grad:
-            input_fp8_t = ctx.input_fp8.fp8_transpose()
+            input_fp8_t = input_fp8.fp8_transpose()
             wgrad_qtype = ctx.output_qtype
             # compute weight gradient
             if ctx.weight.grad is None:
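
The patch moves the saved FP8 tensors from plain `ctx` attributes into `ctx.save_for_backward`, keeping only the scaling metadata (`*_sf`) on `ctx`. Because the raw FP8 payload is stored as uint8, it is viewed as float16 before saving and viewed back to uint8 in `backward`, where the `ScalingTensor` is reconstructed from payload plus meta. The sketch below illustrates that pattern in plain PyTorch, outside the MS-AMP codebase; `QuantizedIdentity` and its ad-hoc int8 quantization are hypothetical stand-ins for `ScalingTensor`, used only to show the save/view/rebuild round trip.

```python
import torch


class QuantizedIdentity(torch.autograd.Function):
    """Hypothetical identity op that stashes an int8-quantized copy of its input."""

    @staticmethod
    def forward(ctx, x):
        # Toy per-tensor quantization: a uint8 payload plus a scale kept as metadata.
        scale = x.detach().abs().max().clamp(min=1e-12) / 127.0
        payload_u8 = (x.detach() / scale).round().clamp(-127, 127).to(torch.int8).view(torch.uint8)
        # save_for_backward gets the payload viewed as float16 (the last dimension
        # must be even for the dtype view); the scale rides on ctx, like the *_sf
        # metas in the patch above.
        ctx.save_for_backward(payload_u8.view(torch.float16))
        ctx.scale = scale
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        (payload_fp16,) = ctx.saved_tensors
        # Undo the dtype trick and rebuild the dequantized tensor from payload + scale.
        payload_u8 = payload_fp16.view(torch.uint8)
        x_hat = payload_u8.view(torch.int8).float() * ctx.scale
        # Identity gradient; x_hat is recomputed only to show the round trip.
        return grad_out + 0.0 * x_hat


x = torch.randn(4, 8, requires_grad=True)
QuantizedIdentity.apply(x).sum().backward()
print(x.grad.shape)  # torch.Size([4, 8])
```

Routing the quantized payload through `save_for_backward` (instead of stashing the whole tensor on `ctx`) lets PyTorch's saved-tensor machinery, e.g. activation checkpointing and saved-tensor hooks, see and manage it, which plain `ctx` attributes bypass.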