Skip to content

aarch64-sme #29

@lfeng14

Description

@lfeng14
  • example

    #include <stdio.h>
    #include <arm_sme.h>
    
    int func1(void) {
        // func1();
        svzero_za(0);
        printf("SME 状态切换演示完成!\n");
        printf("SME 状态切换演示完成!\n");
        return 0;
    }
    
    __arm_new("za") int main(void) __arm_streaming {
        func1();
        printf("SME 状态切换演示完成!\n");
        printf("SME 状态切换演示完成!\n");
        return 0;
    }
    
    arm_streaming 调用print:
    	smstop	sm
    	bl	func1
    	smstart	sm
    arm_streaming 调用 arm_locally_streaming:
    	smstop	sm
    	bl	func1
    	smstart	sm
    arm_streaming 调用 arm_streaming(避免频繁切换):
    	bl	func
    
  • 现在理解了sme streaming模式下的函数属性用法,比如arm_streaming arm_locally_streaming 不加stream属性,主要还是想尽量少插入sme指令避免运行时开销;而又保证流模式功能正常;那我现在za也有同样的疑问,有哪些属性分别哪种场景使用

    • __arm_new_za:独立战场的“新领地”

      • 语义:该函数需要使用 ZA,且不依赖调用者(Caller)的任何矩阵数据。它会创建一个私有的、全新的 ZA 上下文。
      • 编译器行为
        1. 保存:如果 Caller 正在使用 ZA,编译器将其压栈。
        2. 开启与初始化:执行 smstart za,并自动执行 svzero_za(清零)。
        3. 恢复:函数返回前,执行 smstop za 并从栈中恢复 Caller 的数据。
      • 使用场景算子调用的最顶层入口
      • 例子:AlphaFold3 的整个 Attention 算子开始处。你不需要之前的任何矩阵状态,只想从零开始累加。
    • __arm_shared_za:最高效的“协作模式”

      • 语义:该函数与调用者共享同一个 ZA 阵列。它既可能读取 Caller 算好的数据,也可能修改它供 Caller 之后使用。
      • 编译器行为零开销。编译器假设 ZA 已经开启,直接生成计算指令。
      • 使用场景微内核(Micro-kernel)与中间逻辑
      • 例子:你写了一个专门负责“矩阵转置累加”的子函数。它应该被标为 shared_za,这样主算子在循环调用它时,不需要任何切换开销。
    • __arm_preserves_za:礼貌的“借用者”

      • 语义:该函数可能会用到 ZA,但它保证在返回时,ZA 的内容原封不动
      • 编译器行为:编译器允许函数修改 ZA,但必须在 ret 之前负责把动过的地方改回来(通常涉及局部备份到寄存器或栈)。
      • 使用场景插入式的监控或辅助函数
        • 例子:在矩阵运算中间插入一个 Debug 函数,打印 ZA 的某个切片,但不希望影响后续的计算流。
    • __arm_in_za, __arm_out_za, __arm_inout_za (具体流向控制)
      在最新的 ACLE 中,这些是 __arm_shared_za 的细化版本,用于告诉编译器数据是怎么流动的:

      • __arm_in_za:只读 Caller 的 ZA。
      • __arm_out_za:不看旧的,只负责写出新的给 Caller。
      • __arm_inout_za:既读又写(类似 C 语言的 += 操作)。
      • 意义:这给了 LLVM 优化器(如死代码消除、指令重排)极大的空间。如果编译器知道你是 out_za,它可能会提前释放 Caller 的旧数据。
    • 开销小的使用方式:

      • 塔尖(Entrance):使用 __arm_locally_streaming + __arm_new_za
        • 作用:负责与非 SME 环境对接,承担唯一的 smstart 开销。
      • 塔身(Middle-ware):使用 __arm_streaming + __arm_shared_za
        • 作用:串联不同的计算步骤(如先做 MOPA,再做 AddVA),保持硬件状态稳定。
      • 塔基(Micro-kernel):使用 __arm_streaming + __arm_shared_za + static inline
        • 作用:最内层的 FMOPA 指令块,确保没有任何非计算指令。
  • 为什么这么构建没有问题,in表示只读不写:

    #include <stdio.h>
    #include <arm_sme.h>
    
    int func1(void) __arm_in("za") {
        // func1();
        svzero_za(0);
        printf("SME 状态切换演示完成!\n");
        printf("SME 状态切换演示完成!\n");
        return 0;
    }
    
    __arm_new("za") int main(void) __arm_streaming {
        func1();
        printf("SME 状态切换演示完成!\n");
        printf("SME 状态切换演示完成!\n");
          return 0;
      }
    
  • 奇怪的事,不报错:

    clang -march=armv9-a+sme  -o stream stream.c -O0 -S -o -
    
  • 哪时候触发lazysave动作:caller有自己的za数据,callee本身也有自己的za数据,这时候需要做lazysave

  • SME ZA Lazy Save 概述

    SME (Scalable Matrix Extension) 的 ZA Lazy Save 是一种优化的 ZA 状态保存机制,用于函数调用时的上下文保护。

    • 触发条件 (AArch64SMEAttributes.h:106-109)

      bool requiresLazySave(const SMEAttrs &Callee) const {
      return hasZAState() && Callee.hasPrivateZAInterface() &&
      !(Callee.Bitmask & SME_ABI_Routine);
      }

      当满足以下条件时需要 lazy save:

      • 调用方有 ZA 状态 (hasZAState())
      • 被调用方使用私有 ZA 接口 (hasPrivateZAInterface()) - 不共享 ZA 状态
      • 被调用方不是 SME ABI 例程
    • 调用前的设置 (AArch64ISelLowering.cpp:9085-9110)

      在调用前执行以下操作:

      1. 获取 TPIDR2 对象(栈上的保存区域)
      2. 保存需要保存的 ZA 切片数量:使用 RDSVL 指令读取向量长度,存储到 TPIDR2 块偏移 8 的位置
      3. 设置 TPIDR2_EL0 寄存器:指向保存区域的基地址
      4. 发射优化备注 "sets up a lazy save for ZA"
    • 调用后的恢复 (AArch64ISelLowering.cpp:9640-9668)

      调用返回后:

      1. 获取当前的 TPIDR2_EL0 值
      2. 将 TPIDR2 块地址加载到 X0
      3. 使用 RESTORE_ZA 伪指令有条件地恢复 ZA 状态
      • 调用 __arm_tpidr2_restore 运行时支持例程
      1. 最后将 TPIDR2_EL0 重置为 0
    • 为什么叫 "Lazy"?

      关键点在于不是立即保存/恢复整个 ZA 状态,而是:

      • 调用前只设置好保存区域和元数据
      • 实际的保存由被调用方在需要修改 ZA 时才执行(通过 SME ABI 机制)
      • 恢复也是有条件的
    bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
    bool requiresLazySave(const SMEAttrs &Callee) const {
       return hasZAState() && Callee.hasPrivateZAInterface() &&
              !(Callee.Bitmask & SME_ABI_Routine);
     }
     bool hasZAState() const { return isNewZA() || sharesZA(); }
     bool sharesZA() const {
       StateValue State = decodeZAState(Bitmask);
       return State == StateValue::In || State == StateValue::Out ||
              State == StateValue::InOut || State == StateValue::Preserved;
     }
     bool hasAgnosticZAInterface() const { return Bitmask & ZA_State_Agnostic; }
     bool hasSharedZAInterface() const { return sharesZA() || sharesZT0(); }
     bool hasPrivateZAInterface() const {
       return !hasSharedZAInterface() && !hasAgnosticZAInterface();
     }
    
    // 情况 1:如果支持并需要“延迟保存 (Lazy Save)”
      // 这种模式下,CPU 只有在真正发生上下文切换时才去搬运巨大的 ZA 数据
      if (RequiresLazySave) {
        // 1. 获取函数维护的 TPIDR2 对象(包含堆栈帧索引 FrameIndex)
        const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
    
        // 2. 创建机器指针信息,用于 LLVM 的别名分析,标记该内存操作位于栈上
        MachinePointerInfo MPI =
            MachinePointerInfo::getStack(MF, TPIDR2.FrameIndex);
    
        // 3. 获取该 TPIDR2 内存控制块在当前栈帧中的起始地址
        SDValue TPIDR2ObjAddr = DAG.getFrameIndex(
            TPIDR2.FrameIndex,
            DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
    
        // 4. 计算控制块中存储“ZA 切片数量 (Slices)”的地址偏移
        // 根据 SME 规范,该字段通常位于 TPIDR2 控制块偏移 8 字节处
        SDValue NumZaSaveSlicesAddr =
            DAG.getNode(ISD::ADD, DL, TPIDR2ObjAddr.getValueType(), TPIDR2ObjAddr,
                        DAG.getConstant(8, DL, TPIDR2ObjAddr.getValueType()));
    
        // 5. 调用 RDSVL 指令获取当前硬件的流式向量长度 (SVL)
        // 1 SVL 等于当前 ZA 矩阵一列的元素个数(即需要保存的切片数)
        SDValue NumZaSaveSlices = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
                                              DAG.getConstant(1, DL, MVT::i32));
    
        // 6. 将切片数量写入内存控制块。硬件在执行延迟保存时会读取此值。
        Chain = DAG.getTruncStore(Chain, DL, NumZaSaveSlices, NumZaSaveSlicesAddr,
                                  MPI, MVT::i16);
    
        // 7. 【核心动作】设置系统寄存器 TPIDR2_EL0 指向我们的内存控制块
        // 这一步告诉硬件:“如果需要切出当前任务,请把 ZA 存到这个地址。”
        Chain = DAG.getNode(
            ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
            DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
            TPIDR2ObjAddr);
    
        // 8. 编译器诊断:如果开启了 -Rpass-analysis=sme,会告诉你此处生成了延迟保存
        OptimizationRemarkEmitter ORE(&MF.getFunction());
        ORE.emit([&]() {
          auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMELazySaveZA",
                                                       CLI.CB)
                          : OptimizationRemarkAnalysis("sme", "SMELazySaveZA",
                                                       &MF.getFunction());
          return DescribeCallsite(R) << " sets up a lazy save for ZA";
        });
    
      } 
      // 情况 2:如果不满足延迟保存条件,但必须保护 ZA
      else if (RequiresSaveAllZA) {
        // 断言:如果我们共享 ZA 接口,就不应该强制保存(因为逻辑上它应该是透明的)
        assert(!CalleeAttrs.hasSharedZAInterface() &&
               "Cannot share state that may not exist");
    
        // 强制执行全量保存。这会生成一系列 LDR/STR 指令手动搬运整个矩阵。
        // 性能开销极大,是最后的保底手段。
        Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
                                        /*IsSave=*/true);
      }
    
    // 情况 1:处理延迟保存(Lazy Save)的恢复逻辑
      if (RequiresLazySave) {
        // 1. 获取当前函数维护的 TPIDR2 对象(该对象指向内存中用于存放 ZA 数据的块)
        TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
        
        // 2. 获取 SME ABI 支持例程调用时需要保留的寄存器掩码
        // 这通常涉及调用底层的恢复函数,必须知道哪些寄存器在调用后依然有效
        SDValue RegMask = DAG.getRegisterMask(
            TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
            
        // 3. 获取运行时的支持例程符号 "__arm_tpidr2_restore"
        // 这个函数是底层库提供的,负责真正从内存加载数据到物理 ZA 阵列
        SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
            "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
            
        // 4. 读取系统寄存器 TPIDR2_EL0 的当前值(这是一个内建的读取操作)
        // TPIDR2_EL0 存储了指向延迟保存控制块的指针
        SDValue TPIDR2_EL0 = DAG.getNode(
            ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
            DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
    
        // 5. 将存放 ZA 数据的内存块地址(FrameIndex)拷贝到 X0 寄存器
        // 按照调用约定,__arm_tpidr2_restore 的第一个参数通常通过 X0 传递
        SDValue Glue;
        SDValue TPIDR2Block = DAG.getFrameIndex(
            TPIDR2.FrameIndex,
            DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
        Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
        
        // 6. 插入 RESTORE_ZA 伪指令节点
        // 该节点会触发实际的跳转,调用恢复例程,并依赖上面准备好的 X0 和 TPIDR2_EL0
        Result =
            DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
                        {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
                         RestoreRoutine, RegMask, Result.getValue(1)});
    
        // 7. 恢复完成后,必须将 TPIDR2_EL0 寄存器重置为 0
        // 这是为了告诉操作系统和硬件:当前的延迟保存快照已经失效(数据已回到物理寄存器)
        Result = DAG.getNode(
            ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
            DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
            DAG.getConstant(0, DL, MVT::i64));
        
        // 记录该对象的使用次数,用于后续的堆栈空间优化
        TPIDR2.Uses++;
        
      } else if (RequiresSaveAllZA) {
        // 情况 2:非延迟模式,直接全量恢复 ZA
        // 这种模式开销更高,因为它不通过系统例程判断是否需要恢复,而是强制执行
        Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Result,
                                         /*IsSave=*/false);
      }
    

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions