Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions paddle/phi/backends/gpu/rocm/miopen_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ inline miopenDataType_t ToCudnnDataType(const DataType& t) {
case DataType::FLOAT32:
type = miopenFloat;
break;
case DataType::BFLOAT16:
type = miopenBFloat16;
break;
default:
break;
}
Expand Down
9 changes: 7 additions & 2 deletions paddle/phi/kernels/gpu/layer_norm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -786,8 +786,13 @@ template PADDLE_API void LayerNormKernel<double, GPUContext>(

#ifdef PADDLE_WITH_HIP
// MIOPEN do not support double
PD_REGISTER_KERNEL(
layer_norm, GPU, ALL_LAYOUT, phi::LayerNormKernel, float, phi::float16) {
PD_REGISTER_KERNEL(layer_norm,
GPU,
ALL_LAYOUT,
phi::LayerNormKernel,
float,
phi::float16,
phi::bfloat16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
}
Expand Down
6 changes: 4 additions & 2 deletions paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1443,7 +1443,8 @@ PD_REGISTER_KERNEL(conv2d_grad,
ALL_LAYOUT,
phi::ConvCudnnGradKernel,
float,
phi::float16) {}
phi::float16,
phi::bfloat16) {}

PD_REGISTER_KERNEL(conv3d_grad,
GPUDNN,
Expand All @@ -1456,7 +1457,8 @@ PD_REGISTER_KERNEL(conv2d_double_grad,
ALL_LAYOUT,
phi::ConvCudnnGradGradKernel,
float,
phi::float16) {}
phi::float16,
phi::bfloat16) {}

PD_REGISTER_KERNEL(conv3d_double_grad,
GPUDNN,
Expand Down
9 changes: 7 additions & 2 deletions paddle/phi/kernels/gpudnn/conv_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -561,8 +561,13 @@ void Conv3DCudnnKernel(const Context& dev_ctx,
} // namespace phi

#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(
conv2d, GPUDNN, ALL_LAYOUT, phi::ConvCudnnKernel, float, phi::float16) {}
PD_REGISTER_KERNEL(conv2d,
GPUDNN,
ALL_LAYOUT,
phi::ConvCudnnKernel,
float,
phi::float16,
phi::bfloat16) {}

PD_REGISTER_KERNEL(
conv3d, GPUDNN, ALL_LAYOUT, phi::Conv3DCudnnKernel, float, phi::float16) {}
Expand Down
66 changes: 55 additions & 11 deletions test/legacy_test/test_conv2d_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ def _get_padding_with_SAME(input_shape, pool_size, pool_stride):

def create_test_cudnn_class(parent):
@unittest.skipIf(
not (core.is_compiled_with_cuda() or is_custom_device()),
not (
core.is_compiled_with_cuda()
or core.is_compiled_with_rocm()
or is_custom_device()
),
"core is not compiled with CUDA",
)
class TestCUDNNCase(parent):
Expand All @@ -171,7 +175,11 @@ def init_kernel_type(self):

def create_test_cudnn_fp16_class(parent, grad_check=True):
@unittest.skipIf(
not (core.is_compiled_with_cuda() or is_custom_device()),
not (
core.is_compiled_with_cuda()
or core.is_compiled_with_rocm()
or is_custom_device()
),
"core is not compiled with CUDA",
)
class TestConv2DCUDNNFp16(parent):
Expand Down Expand Up @@ -206,7 +214,11 @@ def test_check_grad_no_input(self):

def create_test_cudnn_bf16_class(parent):
@unittest.skipIf(
not (core.is_compiled_with_cuda() or is_custom_device())
not (
core.is_compiled_with_cuda()
or core.is_compiled_with_rocm()
or is_custom_device()
)
or not core.is_bfloat16_supported(get_device_place()),
"core is not compiled with CUDA and do not support bfloat16",
)
Expand Down Expand Up @@ -273,15 +285,23 @@ def init_test_case_2(self):

def create_test_cudnn_channel_last_class(parent):
@unittest.skipIf(
not (core.is_compiled_with_cuda() or is_custom_device()),
not (
core.is_compiled_with_cuda()
or core.is_compiled_with_rocm()
or is_custom_device()
),
"core is not compiled with CUDA",
)
class TestCudnnChannelLastCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = (
np.float32
if (core.is_compiled_with_rocm() or is_custom_device())
if (
core.is_compiled_with_cuda()
or core.is_compiled_with_rocm()
or is_custom_device()
)
else np.float64
)

Expand All @@ -299,7 +319,11 @@ def init_test_case_2(self):

def create_test_cudnn_channel_last_fp16_class(parent, grad_check=True):
@unittest.skipIf(
not (core.is_compiled_with_cuda() or is_custom_device()),
not (
core.is_compiled_with_cuda()
or core.is_compiled_with_rocm()
or is_custom_device()
),
"core is not compiled with CUDA",
)
class TestCudnnChannelLastFp16(parent):
Expand All @@ -308,7 +332,11 @@ def init_kernel_type(self):
self.dtype = np.float16

def test_check_output(self):
if core.is_compiled_with_cuda() or is_custom_device():
if (
core.is_compiled_with_cuda()
or core.is_compiled_with_rocm()
or is_custom_device()
):
place = get_device_place()
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
Expand Down Expand Up @@ -363,15 +391,23 @@ def init_paddings(self):

def create_test_cudnn_padding_SAME_class(parent):
@unittest.skipIf(
not (core.is_compiled_with_cuda() or is_custom_device()),
not (
core.is_compiled_with_cuda()
or core.is_compiled_with_rocm()
or is_custom_device()
),
"core is not compiled with CUDA",
)
class TestCUDNNPaddingSAMECase(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = (
np.float32
if (core.is_compiled_with_rocm() or is_custom_device())
if (
core.is_compiled_with_cuda()
or core.is_compiled_with_rocm()
or is_custom_device()
)
else np.float64
)

Expand All @@ -386,15 +422,23 @@ def init_paddings(self):

def create_test_cudnn_padding_VALID_class(parent):
@unittest.skipIf(
not (core.is_compiled_with_cuda() or is_custom_device()),
not (
core.is_compiled_with_cuda()
or core.is_compiled_with_rocm()
or is_custom_device()
),
"core is not compiled with CUDA",
)
class TestCUDNNPaddingVALIDCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = (
np.float32
if (core.is_compiled_with_rocm() or is_custom_device())
if (
core.is_compiled_with_cuda()
or core.is_compiled_with_rocm()
or is_custom_device()
)
else np.float64
)

Expand Down
Loading