Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ class AttentionCPUBase : public AttentionBase {
OpKernelContext* context,
int beam_width,
Tensor* output_qk) const {
ORT_RETURN_IF_ERROR(ValidateCacheIndirectionValues(cache_indir->Data<int32_t>(), batch_size, beam_width,
past_sequence_length, max_sequence_length));

AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

Expand Down Expand Up @@ -186,6 +189,33 @@ class AttentionCPUBase : public AttentionBase {
}

private:
static Status ValidateCacheIndirectionValues(const int32_t* cache_indirection_data,
int batch_beam_size,
int beam_width,
int past_sequence_length,
int max_sequence_length) {
if (cache_indirection_data == nullptr || beam_width <= 0 || past_sequence_length <= 0) {
return Status::OK();
}

for (int batch_beam_index = 0; batch_beam_index < batch_beam_size; ++batch_beam_index) {
const int32_t* beam_indices = cache_indirection_data +
static_cast<std::ptrdiff_t>(batch_beam_index) * max_sequence_length;
for (int position = 0; position < past_sequence_length; ++position) {
const int32_t beam_index = beam_indices[position];
if (beam_index < 0 || beam_index >= beam_width) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cache_indirection beam index out of range. Expected [0, ", beam_width,
"), got ", beam_index,
" at flattened batch_beam index ", batch_beam_index,
", sequence position ", position);
}
}
}

return Status::OK();
}

// Helper function to compute the attention probs. It does 2 things:
// attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) +
// 1 x mask_data(B, N, S, T)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,20 @@ Status DecoderMaskedMultiHeadAttention<T>::Compute(OpKernelContext* context) con
"If beam width is greater than 1, then cache indirection buffer MUST be present");
}

if (cache_indir != nullptr) {
// Read beam width from cache_indirection shape directly.
// DecoderMaskedMultiHeadAttentionParameters shadows AttentionParameters::beam_width,
// so the value set by CheckInputs on the base class is not visible here.
int cache_beam_width = static_cast<int>(cache_indir->Shape().GetDims()[1]);
if (beam_width != nullptr && beam_width_value != cache_beam_width) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'beam_width' should match cache_indirection dimension 1, got ",
beam_width_value, " and ", cache_beam_width);
}

beam_width_value = cache_beam_width;
}

AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

Expand Down
20 changes: 16 additions & 4 deletions onnxruntime/contrib_ops/cpu/maxpool_with_mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,22 @@ class MaxpoolWithMask : public OpKernel, public PoolBase {
const TensorShape& x_shape = X->Shape();
const TensorShape& m_shape = M->Shape();
ORT_RETURN_IF_NOT(x_shape.NumDimensions() >= 3, "Input dimension cannot be less than 3.");

// TODO: fix this checker later
// ONNXRUNTIME_RETURN_IF_NOT((x_shape[2] == m_shape[2]) && (x_shape[3] == m_shape[3]), " Input shape and mask shape
// mismatch: ", x_shape, " vs ", m_shape);
ORT_RETURN_IF_NOT(m_shape.NumDimensions() == x_shape.NumDimensions(),
"Mask and input must have the same number of dimensions. Got mask dims: ",
m_shape.NumDimensions(), " input dims: ", x_shape.NumDimensions());
const bool input_has_nonzero_channels = x_shape[0] > 0 && x_shape[1] > 0;
// Mask N and C dimensions may differ from input (broadcasting via modulo).
// Only require them to be nonzero to prevent division-by-zero in total_mask_channels.
ORT_RETURN_IF_NOT(!input_has_nonzero_channels || (m_shape[0] > 0 && m_shape[1] > 0),
"Mask N and C dimensions must be greater than 0 when input N and C are greater than 0. "
"Got mask N=",
m_shape[0], " C=", m_shape[1],
" input N=", x_shape[0], " C=", x_shape[1]);
for (size_t i = 2; i < x_shape.NumDimensions(); ++i) {
ORT_RETURN_IF_NOT(m_shape[i] == x_shape[i],
"Mask and input spatial dimensions mismatch at dimension ", i,
": mask=", m_shape[i], " input=", x_shape[i]);
}

TensorShapeVector pads = pool_attrs_.pads;
TensorShapeVector kernel_shape = pool_attrs_.kernel_shape;
Expand Down
9 changes: 5 additions & 4 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention.wgsl.template",
WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_),
WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_),
WGSL_TEMPLATE_PARAMETER(is_apple, is_apple_),
WGSL_TEMPLATE_PARAMETER(is_fp16, is_fp16_),
WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_),
WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_),
Expand All @@ -221,7 +220,8 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_),
WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_),
WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_),
WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_));
WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_),
WGSL_TEMPLATE_PARAMETER(use_shm_path, use_shm_path_));
}

Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const {
Expand Down Expand Up @@ -486,6 +486,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
bool is_apple = context.AdapterInfo().vendor == std::string_view{"apple"};
bool has_subgroups = context.HasFeature(wgpu::FeatureName::Subgroups);
bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH;
bool has_head_sink = head_sink != nullptr;
Expand All @@ -498,6 +499,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
parameters.is_unidirectional_,
is_nvidia,
is_apple,
has_subgroups,
q_BNSH,
use_seqlen_k,
has_head_sink};
Expand Down Expand Up @@ -532,7 +534,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co

program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile)
.SetWorkgroupSize(prefill_tile_size)
.CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, is_apple, q_BNSH, use_seqlen_k, has_head_sink, program.max_k_step())
.CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, is_apple, has_subgroups, q_BNSH, use_seqlen_k, has_head_sink, program.max_k_step())
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(parameters.total_sequence_length_)},
{static_cast<uint32_t>(present_sequence_length)},
Expand Down Expand Up @@ -584,7 +586,6 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
return !parameters.is_packed_qkv_ &&
parameters.head_size_ == parameters.v_head_size_ &&
context.HasFeature(wgpu::FeatureName::Subgroups) &&
((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0);
}

Expand Down
9 changes: 5 additions & 4 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
bool is_unidirectional,
bool is_nvidia,
bool is_apple,
bool has_subgroups,
bool q_BNSH,
bool use_seqlen_k = false,
bool has_head_sink = false)
Expand All @@ -88,12 +89,12 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
qkv_num_heads_(qkv_num_heads),
is_unidirectional_(is_unidirectional),
is_nvidia_(is_nvidia),
is_apple_(is_apple),
use_shm_path_(is_apple || is_nvidia || !has_subgroups),
q_BNSH_(q_BNSH),
use_seqlen_k_(use_seqlen_k),
has_head_sink_(has_head_sink) {
if (is_apple || is_nvidia) {
// On Apple and NVIDIA, use an optimized loop-based path with dynamic max_k_step.
if (use_shm_path_) {
// Use shared-memory loop-based path with dynamic max_k_step.
// Compute max_k_step from workgroup shared memory budget: k_tile + v_tile = 2 * element_size * head_size * max_k_step
const int element_size = is_fp16 ? 2 : 4;
constexpr int kMinWorkgroupStorageBudgetBytes = 16384;
Expand Down Expand Up @@ -130,7 +131,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
int qkv_num_heads_;
bool is_unidirectional_;
bool is_nvidia_;
bool is_apple_;
bool use_shm_path_;
bool q_BNSH_;
bool use_seqlen_k_;
bool has_head_sink_;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@

#param has_attention_bias
#param has_head_sink
#param is_apple
#param is_fp16
#param is_qualcomm
#param is_unidirectional
Expand All @@ -10,6 +9,7 @@
#param qkv_head_size
#param qkv_num_heads
#param use_seqlen_k
#param use_shm_path
#param max_k_step_param

const head_size : u32 = qkv_head_size;
Expand Down Expand Up @@ -61,7 +61,7 @@ fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_
}
}

#if is_apple
#if use_shm_path

var<private> qk_scores : array<q_element_t, max_k_step>;

Expand Down Expand Up @@ -240,7 +240,7 @@ $MAIN {
let seq_causal_length = total_sequence_length;
#endif

#if is_apple
#if use_shm_path

for (var k_start = 0u; k_start < loop_bound; k_start += max_k_step) {
workgroupBarrier();
Expand Down
11 changes: 10 additions & 1 deletion onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1007,8 +1007,17 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
const epctx::BufferWriteFuncHolder* output_write_func_holder = ep_context_gen_options.TryGetOutputModelWriteFunc();
const std::filesystem::path* output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath();

// Determine whether we need to resolve/validate a file system path for the output model.
// A path is needed when:
// - Writing the output model to a file (not to a buffer or write function)
// - Writing initializers to an external file (needs the model path to compute the external file location)
const bool output_is_to_file = (output_buffer_holder == nullptr && output_write_func_holder == nullptr);
const bool needs_path_for_external_initializers =
(ep_context_gen_options.TryGetExternalInitializerFileInfo() != nullptr);

std::filesystem::path valid_output_model_path;
if (output_model_path_ptr != nullptr || !graph.ModelPath().empty()) {
if ((output_is_to_file || needs_path_for_external_initializers) &&
(output_model_path_ptr != nullptr || !graph.ModelPath().empty())) {
std::filesystem::path output_model_path = (output_model_path_ptr != nullptr) ? *output_model_path_ptr
: std::filesystem::path("");
ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(output_model_path,
Expand Down
Loading
Loading