diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index eb33ea76613cf..11b189b27465a 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -3595,7 +3595,8 @@ class OpenMPIRBuilder { InsertPointTy CodeGenIP, llvm::Value *PtrPHI, llvm::Value *BeginArg)> PrivAndGenMapInfoCB, llvm::Type *ElemTy, StringRef FuncName, - CustomMapperCallbackTy CustomMapperCB); + CustomMapperCallbackTy CustomMapperCB, + bool PreserveMemberOfFlags = false); /// Generator for '#omp target data' /// diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index ee98756820e07..3645f8fd98390 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -10272,7 +10272,8 @@ Expected OpenMPIRBuilder::emitUserDefinedMapper( function_ref GenMapInfoCB, - Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB) { + Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB, + bool PreserveMemberOfFlags) { SmallVector Params; Params.emplace_back(Builder.getPtrTy()); Params.emplace_back(Builder.getPtrTy()); @@ -10369,8 +10370,21 @@ Expected OpenMPIRBuilder::emitUserDefinedMapper( Value *OriMapType = Builder.getInt64( static_cast>( Info->Types[I])); - Value *MemberMapType = - Builder.CreateNUWAdd(OriMapType, ShiftedPreviousSize); + Value *MemberMapType; + if (PreserveMemberOfFlags) { + constexpr uint64_t MemberOfMask = + static_cast(OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF); + uint64_t OrigFlags = + static_cast>( + Info->Types[I]); + bool HasMemberOf = (OrigFlags & MemberOfMask) != 0; + if (HasMemberOf) + MemberMapType = Builder.CreateNUWAdd(OriMapType, ShiftedPreviousSize); + else + MemberMapType = OriMapType; + } else { + MemberMapType = Builder.CreateNUWAdd(OriMapType, ShiftedPreviousSize); + } // Combine the map type inherited from user-defined mapper with that // specified in the program. According to the OMP_MAP_TO and OMP_MAP_FROM diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index a56467581b0d7..698d958b4ab2b 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -5889,16 +5889,8 @@ processIndividualMap(llvm::IRBuilderBase &builder, // the host, and then expect it to not be updated in a subsequent impliict map // (such as an implicit map on a target). if (memberOfFlag != llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE) { - // If we are in a declare mapper, we apply MEMBER_OF even if it's an attach - // or pointer map, this is to make the MEMBER_OF flag uniform across all - // maps within the declare mapper, as even if we do not apply it here on - // nestings greater than the first layer we will have a member of flag - // applied automatically. So, we canonicalize it here, which keeps the - // behaviour of pointer/data maps consistent across layers. - if ((!isPtrTy && !isAttachMap) || - mapInfoOp->getParentOfType()) { + if (!isPtrTy && !isAttachMap) ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag); - } // The return parameter should be the over-riding parent in cases where we // have a return parameter that is echoed to all members, the main case of @@ -5915,8 +5907,7 @@ processIndividualMap(llvm::IRBuilderBase &builder, // map-backs in certain cases where an implicit declare mapepr has been // emitted for a target region. Applying MAP_PTR_AND_OBJ in these situations // circumvents this. - if (isPtrTy && !isAttachMap && (mapData.IsDeclareTarget[mapDataIdx] || - mapInfoOp->getParentOfType())) + if (isPtrTy && !isAttachMap && mapData.IsDeclareTarget[mapDataIdx]) mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ; // if we're provided a mapDataParentIdx, then the data being mapped is @@ -6534,7 +6525,8 @@ emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, }; llvm::Expected newFn = ompBuilder->emitUserDefinedMapper( - genMapInfoCB, varType, mapperFuncName, customMapperCB); + genMapInfoCB, varType, mapperFuncName, customMapperCB, + /*PreserveMemberOfFlags=*/true); if (!newFn) return newFn.takeError(); if ([[maybe_unused]] llvm::Function *mappedFunc = diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir index a14dd2e2f5ae2..b80e6220e6646 100644 --- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir @@ -582,27 +582,23 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} { // CHECK: %[[VAL_45:.*]] = getelementptr %[[VAL_18]], ptr %[[VAL_43]], i32 0, i32 0 // CHECK: %[[VAL_46:.*]] = call i64 @__tgt_mapper_num_components(ptr %[[VAL_37]]) // CHECK: %[[VAL_47:.*]] = shl i64 %[[VAL_46]], 48 -// CHECK: %[[VAL_48:.*]] = add nuw i64 3, %[[VAL_47]] // CHECK: %[[VAL_49:.*]] = and i64 %[[VAL_22]], 3 // CHECK: %[[VAL_50:.*]] = icmp eq i64 %[[VAL_49]], 0 // CHECK: br i1 %[[VAL_50]], label %[[VAL_51:.*]], label %[[VAL_52:.*]] // CHECK: omp.type.alloc: ; preds = %[[VAL_41]] -// CHECK: %[[VAL_53:.*]] = and i64 %[[VAL_48]], -4 // CHECK: br label %[[VAL_42]] // CHECK: omp.type.alloc.else: ; preds = %[[VAL_41]] // CHECK: %[[VAL_54:.*]] = icmp eq i64 %[[VAL_49]], 1 // CHECK: br i1 %[[VAL_54]], label %[[VAL_55:.*]], label %[[VAL_56:.*]] // CHECK: omp.type.to: ; preds = %[[VAL_52]] -// CHECK: %[[VAL_57:.*]] = and i64 %[[VAL_48]], -3 // CHECK: br label %[[VAL_42]] // CHECK: omp.type.to.else: ; preds = %[[VAL_52]] // CHECK: %[[VAL_58:.*]] = icmp eq i64 %[[VAL_49]], 2 // CHECK: br i1 %[[VAL_58]], label %[[VAL_59:.*]], label %[[VAL_42]] // CHECK: omp.type.from: ; preds = %[[VAL_56]] -// CHECK: %[[VAL_60:.*]] = and i64 %[[VAL_48]], -2 // CHECK: br label %[[VAL_42]] // CHECK: omp.type.end: ; preds = %[[VAL_59]], %[[VAL_56]], %[[VAL_55]], %[[VAL_51]] -// CHECK: %[[VAL_61:.*]] = phi i64 [ %[[VAL_53]], %[[VAL_51]] ], [ %[[VAL_57]], %[[VAL_55]] ], [ %[[VAL_60]], %[[VAL_59]] ], [ %[[VAL_48]], %[[VAL_56]] ] +// CHECK: %[[VAL_61:.*]] = phi i64 [ 0, %[[VAL_51]] ], [ 1, %[[VAL_55]] ], [ 2, %[[VAL_59]] ], [ 3, %[[VAL_56]] ] // CHECK: call void @__tgt_push_mapper_component(ptr %[[VAL_37]], ptr %[[VAL_45]], ptr %[[VAL_45]], i64 4, i64 %[[VAL_61]], ptr @2) // CHECK: %[[VAL_44]] = getelementptr %[[VAL_18]], ptr %[[VAL_43]], i32 1 // CHECK: %[[VAL_62:.*]] = icmp eq ptr %[[VAL_44]], %[[VAL_17]]