From 183a7caef115fc20de605be2a8025284abf17aaf Mon Sep 17 00:00:00 2001 From: Sander de Smalen Date: Wed, 4 Sep 2024 15:45:45 +0100 Subject: [PATCH] [AArch64] SME implementation for agnostic-ZA functions This implements the lowering of calls from agnostic-ZA functions to non-agnostic-ZA functions, using the ABI routines `__arm_sme_state_size`, `__arm_sme_save` and `__arm_sme_restore`. This implements the proposal described in the following PRs: * https://github.com/ARM-software/acle/pull/336 * https://github.com/ARM-software/abi-aa/pull/264 --- llvm/include/llvm/IR/IntrinsicsAArch64.td | 5 + llvm/lib/Target/AArch64/AArch64FastISel.cpp | 3 +- .../Target/AArch64/AArch64ISelLowering.cpp | 134 +++++++++++++++++- llvm/lib/Target/AArch64/AArch64ISelLowering.h | 6 + .../AArch64/AArch64MachineFunctionInfo.h | 14 ++ .../lib/Target/AArch64/AArch64SMEInstrInfo.td | 16 +++ .../AArch64/AArch64TargetTransformInfo.cpp | 19 ++- .../AArch64/Utils/AArch64SMEAttributes.cpp | 9 ++ .../AArch64/Utils/AArch64SMEAttributes.h | 18 ++- llvm/test/CodeGen/AArch64/sme-agnostic-za.ll | 98 +++++++++++++ .../AArch64/sme-disable-gisel-fisel.ll | 24 ++++ 11 files changed, 338 insertions(+), 8 deletions(-) create mode 100644 llvm/test/CodeGen/AArch64/sme-agnostic-za.ll diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td index 8ac1d67e162f70..ecd900b1c70d80 100644 --- a/llvm/include/llvm/IR/IntrinsicsAArch64.td +++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td @@ -3017,6 +3017,11 @@ let TargetPrefix = "aarch64" in { def int_aarch64_sme_za_disable : DefaultAttrsIntrinsic<[], [], [IntrNoMem, IntrHasSideEffects]>; + def int_aarch64_sme_save + : DefaultAttrsIntrinsic<[], [llvm_anyptr_ty], [IntrNoMem, IntrHasSideEffects]>; + def int_aarch64_sme_restore + : DefaultAttrsIntrinsic<[], [llvm_anyptr_ty], [IntrNoMem, IntrHasSideEffects]>; + // Clamp // diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp index cbf38f2c57a35e..2edfeb854fd763 100644 --- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp +++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp @@ -5195,7 +5195,8 @@ FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo, SMEAttrs CallerAttrs(*FuncInfo.Fn); if (CallerAttrs.hasZAState() || CallerAttrs.hasZT0State() || CallerAttrs.hasStreamingInterfaceOrBody() || - CallerAttrs.hasStreamingCompatibleInterface()) + CallerAttrs.hasStreamingCompatibleInterface() || + CallerAttrs.hasAgnosticZAInterface()) return nullptr; return new AArch64FastISel(FuncInfo, LibInfo); } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index d1ddbfa300846b..060961b3748e9d 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -76,6 +76,7 @@ #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" #include "llvm/MC/MCRegisterInfo.h" +#include "llvm/MC/MCContext.h" #include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CodeGen.h" @@ -2561,6 +2562,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { break; MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER) MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ) + MAKE_CASE(AArch64ISD::GET_SME_SAVE_SIZE) + MAKE_CASE(AArch64ISD::ALLOC_SME_SAVE_BUFFER) MAKE_CASE(AArch64ISD::COALESCER_BARRIER) MAKE_CASE(AArch64ISD::VG_SAVE) MAKE_CASE(AArch64ISD::VG_RESTORE) @@ -3146,6 +3149,42 @@ AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI, return BB; } 
+MachineBasicBlock *
+AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
+                                                 MachineBasicBlock *BB) const {
+  MachineFunction *MF = BB->getParent();
+  MachineFrameInfo &MFI = MF->getFrameInfo();
+  AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+  // TODO: This function grows the stack with a subtraction, which doesn't work
+  // on Windows. Some refactoring to share the functionality in
+  // LowerWindowsDYNAMIC_STACKALLOC will be required once the Windows ABI
+  // supports SME.
+  assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
+         "Lazy ZA save is not yet supported on Windows");
+
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  if (FuncInfo->getSMESaveBufferUsed()) {
+    // Allocate a buffer object of the size given by __arm_sme_state_size().
+    auto Size = MI.getOperand(1).getReg();
+    auto Dest = MI.getOperand(0).getReg();
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), Dest)
+        .addReg(AArch64::SP)
+        .addReg(Size)
+        .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+            AArch64::SP)
+        .addReg(Dest);
+
+    // We have just allocated a variable sized object, tell this to PEI.
+    MFI.CreateVariableSizedObject(Align(16), nullptr);
+  } else
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::IMPLICIT_DEF),
+            MI.getOperand(0).getReg());
+
+  BB->remove_instr(&MI);
+  return BB;
+}
+
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
 
@@ -3180,6 +3219,31 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     return EmitInitTPIDR2Object(MI, BB);
   case AArch64::AllocateZABuffer:
     return EmitAllocateZABuffer(MI, BB);
+  case AArch64::AllocateSMESaveBuffer:
+    return EmitAllocateSMESaveBuffer(MI, BB);
+  case AArch64::GetSMESaveSize: {
+    // If the buffer is used, emit a call to __arm_sme_state_size()
+    MachineFunction *MF = BB->getParent();
+    AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+    const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+    if (FuncInfo->getSMESaveBufferUsed()) {
+      const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
+          .addExternalSymbol("__arm_sme_state_size")
+          .addReg(AArch64::X0, RegState::ImplicitDefine)
+          .addRegMask(TRI->getCallPreservedMask(
+              *MF, CallingConv::
+                       AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+              MI.getOperand(0).getReg())
+          .addReg(AArch64::X0);
+    } else
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+              MI.getOperand(0).getReg())
+          .addReg(AArch64::XZR);
+    BB->remove_instr(&MI);
+    return BB;
+  }
   case AArch64::F128CSEL:
     return EmitF128CSEL(MI, BB);
   case TargetOpcode::STATEPOINT:
@@ -5645,6 +5709,28 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
         Op->getOperand(0), // Chain
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
         DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
+  case Intrinsic::aarch64_sme_save:
+  case Intrinsic::aarch64_sme_restore: {
+    ArgListTy Args;
+    ArgListEntry Entry;
+    Entry.Ty = PointerType::getUnqual(*DAG.getContext());
+    Entry.Node = Op.getOperand(2);
+    Args.push_back(Entry);
+
+    SDValue Callee = DAG.getExternalSymbol(IntNo == Intrinsic::aarch64_sme_save
+                                               ?
"__arm_sme_save" + : "__arm_sme_restore", + getPointerTy(DAG.getDataLayout())); + auto *RetTy = Type::getVoidTy(*DAG.getContext()); + TargetLowering::CallLoweringInfo CLI(DAG); + CLI.setDebugLoc(DL) + .setChain(Op.getOperand(0)) + .setLibCallee( + CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, + RetTy, Callee, std::move(Args)); + + return LowerCallTo(CLI).second; + } } } @@ -7397,6 +7483,7 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC, case CallingConv::AArch64_VectorCall: case CallingConv::AArch64_SVE_VectorCall: case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0: + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1: case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2: return CC_AArch64_AAPCS; case CallingConv::ARM64EC_Thunk_X64: @@ -7858,6 +7945,30 @@ SDValue AArch64TargetLowering::LowerFormalArguments( Chain = DAG.getNode( AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other), {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)}); + } else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) { + // Call __arm_sme_state_size(). + SDValue BufferSize = DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL, + DAG.getVTList(MVT::i64, MVT::Other), Chain); + Chain = BufferSize.getValue(1); + + SDValue Buffer; + if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) { + Buffer = + DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL, + DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize}); + } else { + // Allocate space dynamically. + Buffer = DAG.getNode( + ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other), + {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)}); + MFI.CreateVariableSizedObject(Align(16), nullptr); + } + + // Copy the value to a virtual register, and save that in FuncInfo. + Register BufferPtr = + MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass); + FuncInfo->setSMESaveBufferAddr(BufferPtr); + Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer); } if (CallConv == CallingConv::PreserveNone) { @@ -8146,6 +8257,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( auto CalleeAttrs = CLI.CB ? 
SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal); if (CallerAttrs.requiresSMChange(CalleeAttrs) || CallerAttrs.requiresLazySave(CalleeAttrs) || + CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) || CallerAttrs.hasStreamingBody()) return false; @@ -8559,6 +8671,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, }; bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs); + bool RequiresSaveAllZA = + CallerAttrs.requiresPreservingAllZAState(CalleeAttrs); + SDValue ZAStateBuffer; if (RequiresLazySave) { const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj(); MachinePointerInfo MPI = @@ -8585,6 +8700,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, &MF.getFunction()); return DescribeCallsite(R) << " sets up a lazy save for ZA"; }); + } else if (RequiresSaveAllZA) { + Chain = + DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Chain, + DAG.getConstant(Intrinsic::aarch64_sme_save, DL, MVT::i32), + DAG.getCopyFromReg( + Chain, DL, FuncInfo->getSMESaveBufferAddr(), MVT::i64)); } SDValue PStateSM; @@ -9138,9 +9259,17 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32), DAG.getConstant(0, DL, MVT::i64)); TPIDR2.Uses++; + } else if (RequiresSaveAllZA) { + Result = DAG.getNode( + ISD::INTRINSIC_VOID, DL, MVT::Other, Result, + DAG.getConstant(Intrinsic::aarch64_sme_restore, DL, MVT::i32), + DAG.getCopyFromReg(Result, DL, FuncInfo->getSMESaveBufferAddr(), + MVT::i64)); + FuncInfo->setSMESaveBufferUsed(); } - if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) { + if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 || + RequiresSaveAllZA) { for (unsigned I = 0; I < InVals.size(); ++I) { // The smstart/smstop is chained as part of the call, but when the // resulting chain is discarded (which happens when the call is not part @@ -27819,7 +27948,8 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const { auto CalleeAttrs = SMEAttrs(*Base); if (CallerAttrs.requiresSMChange(CalleeAttrs) || CallerAttrs.requiresLazySave(CalleeAttrs) || - CallerAttrs.requiresPreservingZT0(CalleeAttrs)) + CallerAttrs.requiresPreservingZT0(CalleeAttrs) || + CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) return true; } return false; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index f9d45b02d30e30..19b4e3916f0458 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -460,6 +460,10 @@ enum NodeType : unsigned { ALLOCATE_ZA_BUFFER, INIT_TPIDR2OBJ, + // Needed for __arm_agnostic("sme_za_state") + GET_SME_SAVE_SIZE, + ALLOC_SME_SAVE_BUFFER, + // Asserts that a function argument (i32) is zero-extended to i8 by // the caller ASSERT_ZEXT_BOOL, @@ -663,6 +667,8 @@ class AArch64TargetLowering : public TargetLowering { MachineBasicBlock *BB) const; MachineBasicBlock *EmitAllocateZABuffer(MachineInstr &MI, MachineBasicBlock *BB) const; + MachineBasicBlock *EmitAllocateSMESaveBuffer(MachineInstr &MI, + MachineBasicBlock *BB) const; MachineBasicBlock * EmitInstrWithCustomInserter(MachineInstr &MI, diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h index 72f110cebbdc8f..7e48ead8acb760 100644 --- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h +++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h @@ -224,6 +224,14 @@ class AArch64FunctionInfo final : public MachineFunctionInfo { // on 
function entry to record the initial pstate of a function.
   Register PStateSMReg = MCRegister::NoRegister;
 
+  // Holds a pointer to a buffer that is large enough to represent
+  // all SME ZA state and any additional state required by the
+  // __arm_sme_save/restore support routines.
+  Register SMESaveBufferAddr = MCRegister::NoRegister;
+
+  // true if SMESaveBufferAddr is used.
+  bool SMESaveBufferUsed = false;
+
   // Has the PNReg used to build PTRUE instruction.
   // The PTRUE is used for the LD/ST of ZReg pairs in save and restore.
   unsigned PredicateRegForFillSpill = 0;
@@ -247,6 +255,12 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
     return PredicateRegForFillSpill;
   }
 
+  Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };
+  void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; };
+
+  unsigned getSMESaveBufferUsed() const { return SMESaveBufferUsed; };
+  void setSMESaveBufferUsed(bool Used = true) { SMESaveBufferUsed = Used; };
+
   Register getPStateSMReg() const { return PStateSMReg; };
   void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
 
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index ebe4121c944b1e..adb9779375298d 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -52,6 +52,22 @@ let usesCustomInserter = 1 in {
   def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer), [(AArch64InitTPIDR2Obj GPR64:$buffer)]>, Sched<[WriteI]> {}
 }
 
+// Nodes to allocate a save buffer for SME.
+def AArch64SMESaveSize : SDNode<"AArch64ISD::GET_SME_SAVE_SIZE", SDTypeProfile<1, 0,
+                                [SDTCisInt<0>]>, [SDNPHasChain]>;
+let usesCustomInserter = 1, Defs = [X0] in {
+  def GetSMESaveSize : Pseudo<(outs GPR64:$dst), (ins), []>, Sched<[]> {}
+}
+def : Pat<(i64 AArch64SMESaveSize), (GetSMESaveSize)>;
+
+def AArch64AllocateSMESaveBuffer : SDNode<"AArch64ISD::ALLOC_SME_SAVE_BUFFER", SDTypeProfile<1, 1,
+                                          [SDTCisInt<0>, SDTCisInt<1>]>, [SDNPHasChain]>;
+let usesCustomInserter = 1, Defs = [SP] in {
+  def AllocateSMESaveBuffer : Pseudo<(outs GPR64sp:$dst), (ins GPR64:$size), []>, Sched<[WriteI]> {}
+}
+def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)),
+          (AllocateSMESaveBuffer $size)>;
+
 //===----------------------------------------------------------------------===//
 // Instruction naming conventions.
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 58c267f1ce4bd6..e03e3066a7f650 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -234,6 +234,17 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
           (cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
            isSMEABIRoutineCall(cast<CallInst>(I))))
         return true;
+
+      if (auto *CB = dyn_cast<CallBase>(&I)) {
+        SMEAttrs CallerAttrs(*CB->getCaller()),
+            CalleeAttrs(*CB->getCalledFunction());
+        // When trying to determine if we can inline callees, we must check
+        // that agnostic-ZA functions don't call any functions that are not
+        // agnostic-ZA, as that would require inserting
+        // save/restore code.
+ if (CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) + return true; + } } } return false; @@ -255,7 +266,13 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller, if (CallerAttrs.requiresLazySave(CalleeAttrs) || CallerAttrs.requiresSMChange(CalleeAttrs) || - CallerAttrs.requiresPreservingZT0(CalleeAttrs)) { + CallerAttrs.requiresPreservingZT0(CalleeAttrs) || + CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) { + if (hasPossibleIncompatibleOps(Callee)) + return false; + } + + if (CalleeAttrs.hasAgnosticZAInterface()) { if (hasPossibleIncompatibleOps(Callee)) return false; } diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index 015ca4cb92b25e..bf16acd7f8f7e1 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -38,6 +38,10 @@ void SMEAttrs::set(unsigned M, bool Enable) { isPreservesZT0())) && "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', " "'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive"); + + assert(!(hasAgnosticZAInterface() && hasSharedZAInterface()) && + "Function cannot have a shared-ZA interface and an agnostic-ZA " + "interface"); } SMEAttrs::SMEAttrs(const CallBase &CB) { @@ -56,6 +60,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) { if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" || FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr") Bitmask |= SMEAttrs::SM_Compatible; + if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" || + FuncName == "__arm_sme_state_size") + Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine; } SMEAttrs::SMEAttrs(const AttributeList &Attrs) { @@ -66,6 +73,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) { Bitmask |= SM_Compatible; if (Attrs.hasFnAttr("aarch64_pstate_sm_body")) Bitmask |= SM_Body; + if (Attrs.hasFnAttr("aarch64_za_state_agnostic")) + Bitmask |= ZA_State_Agnostic; if (Attrs.hasFnAttr("aarch64_in_za")) Bitmask |= encodeZAState(StateValue::In); if (Attrs.hasFnAttr("aarch64_out_za")) diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h index 4c7c1c9b079538..816e7d56618d14 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h @@ -42,9 +42,10 @@ class SMEAttrs { SM_Compatible = 1 << 1, // aarch64_pstate_sm_compatible SM_Body = 1 << 2, // aarch64_pstate_sm_body SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves - ZA_Shift = 4, + ZA_State_Agnostic = 1 << 4, + ZA_Shift = 5, ZA_Mask = 0b111 << ZA_Shift, - ZT0_Shift = 7, + ZT0_Shift = 8, ZT0_Mask = 0b111 << ZT0_Shift }; @@ -96,8 +97,11 @@ class SMEAttrs { return State == StateValue::In || State == StateValue::Out || State == StateValue::InOut || State == StateValue::Preserved; } + bool hasAgnosticZAInterface() const { return Bitmask & ZA_State_Agnostic; } bool hasSharedZAInterface() const { return sharesZA() || sharesZT0(); } - bool hasPrivateZAInterface() const { return !hasSharedZAInterface(); } + bool hasPrivateZAInterface() const { + return !hasSharedZAInterface() && !hasAgnosticZAInterface(); + } bool hasZAState() const { return isNewZA() || sharesZA(); } bool requiresLazySave(const SMEAttrs &Callee) const { return hasZAState() && Callee.hasPrivateZAInterface() && @@ -128,7 +132,8 @@ class SMEAttrs { } bool hasZT0State() const { return 
isNewZT0() || sharesZT0(); } bool requiresPreservingZT0(const SMEAttrs &Callee) const { - return hasZT0State() && !Callee.sharesZT0(); + return hasZT0State() && !Callee.sharesZT0() && + !Callee.hasAgnosticZAInterface(); } bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const { return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface() && @@ -137,6 +142,11 @@ class SMEAttrs { bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const { return requiresLazySave(Callee) || requiresDisablingZABeforeCall(Callee); } + bool + requiresPreservingAllZAState(const SMEAttrs &Callee) const { + return hasAgnosticZAInterface() && !Callee.hasAgnosticZAInterface() && + !(Callee.Bitmask & SME_ABI_Routine); + } }; } // namespace llvm diff --git a/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll new file mode 100644 index 00000000000000..e3d8c8d2620822 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll @@ -0,0 +1,98 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc -mattr=+sme2 < %s | FileCheck %s + +target triple = "aarch64" + +declare i64 @private_za_decl(i64) +declare i64 @agnostic_decl(i64) "aarch64_za_state_agnostic" + +; No calls. Test that no buffer is allocated. +define i64 @agnostic_caller_no_callees(ptr %ptr) nounwind "aarch64_za_state_agnostic" { +; CHECK-LABEL: agnostic_caller_no_callees: +; CHECK: // %bb.0: +; CHECK-NEXT: ldr x0, [x0] +; CHECK-NEXT: ret + %v = load i64, ptr %ptr + ret i64 %v +} + +; Use regular DYNAMIC_STACKALLOC for allocation, prevents alloca from being removed entirely. +define i64 @agnostic_caller_no_callees_stackprobe(ptr %ptr) nounwind "aarch64_za_state_agnostic" "probe-stack"="inline-asm" { +; CHECK-LABEL: agnostic_caller_no_callees_stackprobe: +; CHECK: // %bb.0: +; CHECK-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill +; CHECK-NEXT: mov x29, sp +; CHECK-NEXT: ldr x0, [x0] +; CHECK-NEXT: mov sp, x29 +; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload +; CHECK-NEXT: ret + %v = load i64, ptr %ptr + ret i64 %v +} + +; agnostic-ZA -> private-ZA +; +; Test that a buffer is allocated and that the appropriate save/restore calls are +; inserted for calls to non-agnostic functions and that the arg/result registers are +; preserved by the register allocator. +define i64 @agnostic_caller_private_za_callee(i64 %v) nounwind "aarch64_za_state_agnostic" { +; CHECK-LABEL: agnostic_caller_private_za_callee: +; CHECK: // %bb.0: +; CHECK-NEXT: stp x29, x30, [sp, #-32]! 
// 16-byte Folded Spill +; CHECK-NEXT: str x19, [sp, #16] // 8-byte Folded Spill +; CHECK-NEXT: mov x29, sp +; CHECK-NEXT: mov x8, x0 +; CHECK-NEXT: bl __arm_sme_state_size +; CHECK-NEXT: sub x19, sp, x0 +; CHECK-NEXT: mov sp, x19 +; CHECK-NEXT: mov x0, x19 +; CHECK-NEXT: bl __arm_sme_save +; CHECK-NEXT: mov x0, x8 +; CHECK-NEXT: bl private_za_decl +; CHECK-NEXT: mov x1, x0 +; CHECK-NEXT: mov x0, x19 +; CHECK-NEXT: bl __arm_sme_restore +; CHECK-NEXT: mov x0, x19 +; CHECK-NEXT: bl __arm_sme_save +; CHECK-NEXT: mov x0, x1 +; CHECK-NEXT: bl private_za_decl +; CHECK-NEXT: mov x1, x0 +; CHECK-NEXT: mov x0, x19 +; CHECK-NEXT: bl __arm_sme_restore +; CHECK-NEXT: mov x0, x1 +; CHECK-NEXT: mov sp, x29 +; CHECK-NEXT: ldr x19, [sp, #16] // 8-byte Folded Reload +; CHECK-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload +; CHECK-NEXT: ret + %res = call i64 @private_za_decl(i64 %v) + %res2 = call i64 @private_za_decl(i64 %res) + ret i64 %res2 +} + +; agnostic-ZA -> agnostic-ZA +; +; Should not result in save/restore code. +define i64 @agnostic_caller_agnostic_callee(i64 %v) nounwind "aarch64_za_state_agnostic" { +; CHECK-LABEL: agnostic_caller_agnostic_callee: +; CHECK: // %bb.0: +; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill +; CHECK-NEXT: bl agnostic_decl +; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload +; CHECK-NEXT: ret + %res = call i64 @agnostic_decl(i64 %v) + ret i64 %res +} + +; shared-ZA -> agnostic-ZA +; +; Should not result in lazy-save or save of ZT0 +define i64 @shared_caller_agnostic_callee(i64 %v) nounwind "aarch64_inout_za" "aarch64_inout_zt0" { +; CHECK-LABEL: shared_caller_agnostic_callee: +; CHECK: // %bb.0: +; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill +; CHECK-NEXT: bl agnostic_decl +; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload +; CHECK-NEXT: ret + %res = call i64 @agnostic_decl(i64 %v) + ret i64 %res +} diff --git a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll index 42dba22d257089..d9dc2ad841f167 100644 --- a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll +++ b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll @@ -526,3 +526,27 @@ entry: %add = fadd double %call, 4.200000e+01 ret double %add; } + +define void @agnostic_za_function(ptr %ptr) nounwind "aarch64_za_state_agnostic" { +; CHECK-COMMON-LABEL: agnostic_za_function: +; CHECK-COMMON: // %bb.0: +; CHECK-COMMON-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill +; CHECK-COMMON-NEXT: stp x20, x19, [sp, #16] // 16-byte Folded Spill +; CHECK-COMMON-NEXT: mov x29, sp +; CHECK-COMMON-NEXT: mov x8, x0 +; CHECK-COMMON-NEXT: bl __arm_sme_state_size +; CHECK-COMMON-NEXT: sub x20, sp, x0 +; CHECK-COMMON-NEXT: mov sp, x20 +; CHECK-COMMON-NEXT: mov x0, x20 +; CHECK-COMMON-NEXT: bl __arm_sme_save +; CHECK-COMMON-NEXT: blr x8 +; CHECK-COMMON-NEXT: mov x0, x20 +; CHECK-COMMON-NEXT: bl __arm_sme_restore +; CHECK-COMMON-NEXT: mov sp, x29 +; CHECK-COMMON-NEXT: ldp x20, x19, [sp, #16] // 16-byte Folded Reload +; CHECK-COMMON-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload +; CHECK-COMMON-NEXT: ret + call void %ptr() + ret void +} +
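-- 
For illustration, a C-level sketch of the behaviour implemented here, following
the ACLE proposal linked in the commit message. The function names below are
hypothetical; only `__arm_agnostic("sme_za_state")` and the `__arm_sme_*`
support routines come from the proposal, and the expansion in the comments is
the conceptual equivalent of the lowering added in AArch64ISelLowering.cpp (see
the tests above for the exact generated code):

    // Hypothetical private-ZA callee: it may clobber ZA/ZT0.
    long private_za_callee(long v);

    // Hypothetical agnostic-ZA caller: it must preserve all ZA state around
    // the call without knowing at compile time how much state the target has.
    long agnostic_caller(long v) __arm_agnostic("sme_za_state") {
      // Conceptually, this patch expands the call below into:
      //   size = __arm_sme_state_size();   // query the size of the save area
      //   buf  = alloca(size);             // allocate the save buffer
      //   __arm_sme_save(buf);             // save all ZA state
      //   ... call private_za_callee(v) ...
      //   __arm_sme_restore(buf);          // restore all ZA state
      return private_za_callee(v);
    }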