[AArch64] SME implementation for agnostic-ZA functions
This implements the lowering of calls from agnostic-ZA
functions to non-agnostic-ZA functions, using the ABI routines
`__arm_sme_state_size`, `__arm_sme_save` and `__arm_sme_restore`;
a sketch of the resulting call protocol follows the PR links below.

This implements the proposal described in the following PRs:
* ARM-software/acle#336
* ARM-software/abi-aa#264
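
In effect, each call from an agnostic-ZA function to a private-ZA function is bracketed by the support routines, with a save buffer allocated once in the caller's prologue. A minimal C++ sketch of the resulting protocol, assuming the routine signatures implied by the lowering in this patch (the generated code keeps the buffer on the stack rather than the heap, and `private_za_callee` is a hypothetical name):

    #include <cstdint>
    #include <cstdlib>

    // SME ABI support routines; signatures assumed from this patch and
    // ARM-software/abi-aa#264.
    extern "C" uint64_t __arm_sme_state_size(void);
    extern "C" void __arm_sme_save(void *buffer);
    extern "C" void __arm_sme_restore(void *buffer);

    extern "C" void private_za_callee(void); // any non-agnostic-ZA function

    // What an agnostic-ZA caller conceptually does around the call:
    extern "C" void agnostic_za_caller(void) {
      // Prologue: allocate a buffer big enough for all SME state.
      void *buffer = std::malloc(__arm_sme_state_size());
      __arm_sme_save(buffer);    // spill ZA (and related state) before the call
      private_za_callee();       // the callee may freely clobber ZA
      __arm_sme_restore(buffer); // refill the state afterwards
      std::free(buffer);
    }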
sdesmalen-arm committed Sep 9, 2024
1 parent 9cae4ef commit 183a7ca
Showing 11 changed files with 338 additions and 8 deletions.
5 changes: 5 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -3017,6 +3017,11 @@ let TargetPrefix = "aarch64" in {
def int_aarch64_sme_za_disable
: DefaultAttrsIntrinsic<[], [], [IntrNoMem, IntrHasSideEffects]>;

def int_aarch64_sme_save
: DefaultAttrsIntrinsic<[], [llvm_anyptr_ty], [IntrNoMem, IntrHasSideEffects]>;
def int_aarch64_sme_restore
: DefaultAttrsIntrinsic<[], [llvm_anyptr_ty], [IntrNoMem, IntrHasSideEffects]>;

// Clamp
//

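The two intrinsics are overloaded on the pointer type of their single operand. A hedged sketch of how a pass could emit them around a call, mirroring what LowerCall does later in this patch (the helper name is hypothetical):

    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/IntrinsicsAArch64.h"

    using namespace llvm;

    // Wraps an existing call with the new save/restore intrinsics.
    // Buf must point to at least __arm_sme_state_size() bytes.
    static void wrapWithZASaveRestore(IRBuilder<> &B, Value *Buf,
                                      FunctionCallee Callee) {
      // Both intrinsics take the pointer type as their overload type.
      B.CreateIntrinsic(Intrinsic::aarch64_sme_save, {Buf->getType()}, {Buf});
      B.CreateCall(Callee);
      B.CreateIntrinsic(Intrinsic::aarch64_sme_restore, {Buf->getType()}, {Buf});
    }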
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -5195,7 +5195,8 @@ FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
SMEAttrs CallerAttrs(*FuncInfo.Fn);
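// Fall back to SelectionDAG ISel for any function with SME state, including
// the new agnostic-ZA interface: FastISel cannot set up the ZA save/restore
// these functions require.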
if (CallerAttrs.hasZAState() || CallerAttrs.hasZT0State() ||
CallerAttrs.hasStreamingInterfaceOrBody() ||
CallerAttrs.hasStreamingCompatibleInterface())
CallerAttrs.hasStreamingCompatibleInterface() ||
CallerAttrs.hasAgnosticZAInterface())
return nullptr;
return new AArch64FastISel(FuncInfo, LibInfo);
}
134 changes: 132 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -76,6 +76,7 @@
#include "llvm/IR/Use.h"
#include "llvm/IR/Value.h"
#include "llvm/MC/MCRegisterInfo.h"
#include "llvm/MC/MCContext.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
@@ -2561,6 +2562,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
break;
MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER)
MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ)
MAKE_CASE(AArch64ISD::GET_SME_SAVE_SIZE)
MAKE_CASE(AArch64ISD::ALLOC_SME_SAVE_BUFFER)
MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
MAKE_CASE(AArch64ISD::VG_SAVE)
MAKE_CASE(AArch64ISD::VG_RESTORE)
@@ -3146,6 +3149,42 @@ AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI,
return BB;
}

MachineBasicBlock *
AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
MachineBasicBlock *BB) const {
MachineFunction *MF = BB->getParent();
MachineFrameInfo &MFI = MF->getFrameInfo();
AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
// TODO: This function grows the stack with a subtraction, which doesn't work
// on Windows. Some refactoring to share the functionality in
// LowerWindowsDYNAMIC_STACKALLOC will be required once the Windows ABI
// supports SME.
assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
"Agnostic-ZA save-buffer allocation is not yet supported on Windows");

const TargetInstrInfo *TII = Subtarget->getInstrInfo();
if (FuncInfo->getSMESaveBufferUsed()) {
// Allocate a buffer on the stack large enough to hold the full SME state;
// the size operand was produced by GetSMESaveSize (__arm_sme_state_size).
auto Size = MI.getOperand(1).getReg();
auto Dest = MI.getOperand(0).getReg();
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), Dest)
.addReg(AArch64::SP)
.addReg(Size)
.addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
AArch64::SP)
.addReg(Dest);

// We have just allocated a variable sized object, tell this to PEI.
MFI.CreateVariableSizedObject(Align(16), nullptr);
} else
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::IMPLICIT_DEF),
MI.getOperand(0).getReg());

BB->remove_instr(&MI);
return BB;
}
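
When the buffer is actually used, the pseudo above expands to a subtract-and-copy on SP. A minimal sketch of the arithmetic it performs (assuming, per the ABI proposal, that __arm_sme_state_size returns a size that keeps SP 16-byte aligned — an assumption):

    #include <cstdint>

    // Models the SUBXrx64 + COPY pair emitted above: Dest = SP - Size; SP = Dest.
    static uint64_t allocateSaveBuffer(uint64_t &SP, uint64_t Size) {
      uint64_t Dest = SP - Size; // SUBXrx64: compute the new stack pointer
      SP = Dest;                 // COPY: commit it to SP
      return Dest;               // the save buffer starts at the new SP
    }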

MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
MachineInstr &MI, MachineBasicBlock *BB) const {

@@ -3180,6 +3219,31 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
return EmitInitTPIDR2Object(MI, BB);
case AArch64::AllocateZABuffer:
return EmitAllocateZABuffer(MI, BB);
case AArch64::AllocateSMESaveBuffer:
return EmitAllocateSMESaveBuffer(MI, BB);
case AArch64::GetSMESaveSize: {
// If the buffer is used, emit a call to __arm_sme_state_size(); otherwise
// the size is simply zero (XZR).
MachineFunction *MF = BB->getParent();
AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
if (FuncInfo->getSMESaveBufferUsed()) {
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
.addExternalSymbol("__arm_sme_state_size")
.addReg(AArch64::X0, RegState::ImplicitDefine)
.addRegMask(TRI->getCallPreservedMask(
*MF, CallingConv::
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
MI.getOperand(0).getReg())
.addReg(AArch64::X0);
} else
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
MI.getOperand(0).getReg())
.addReg(AArch64::XZR);
BB->remove_instr(&MI);
return BB;
}
case AArch64::F128CSEL:
return EmitF128CSEL(MI, BB);
case TargetOpcode::STATEPOINT:
@@ -5645,6 +5709,28 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
Op->getOperand(0), // Chain
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
case Intrinsic::aarch64_sme_save:
case Intrinsic::aarch64_sme_restore: {
ArgListTy Args;
ArgListEntry Entry;
Entry.Ty = PointerType::getUnqual(*DAG.getContext());
Entry.Node = Op.getOperand(2);
Args.push_back(Entry);

SDValue Callee = DAG.getExternalSymbol(IntNo == Intrinsic::aarch64_sme_save
? "__arm_sme_save"
: "__arm_sme_restore",
getPointerTy(DAG.getDataLayout()));
auto *RetTy = Type::getVoidTy(*DAG.getContext());
TargetLowering::CallLoweringInfo CLI(DAG);
CLI.setDebugLoc(DL)
.setChain(Op.getOperand(0))
.setLibCallee(
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1,
RetTy, Callee, std::move(Args));

return LowerCallTo(CLI).second;
}
}
}

@@ -7397,6 +7483,7 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
case CallingConv::AArch64_VectorCall:
case CallingConv::AArch64_SVE_VectorCall:
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1:
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
return CC_AArch64_AAPCS;
case CallingConv::ARM64EC_Thunk_X64:
@@ -7858,6 +7945,30 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
Chain = DAG.getNode(
AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
{/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
} else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
// Call __arm_sme_state_size().
SDValue BufferSize = DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
DAG.getVTList(MVT::i64, MVT::Other), Chain);
Chain = BufferSize.getValue(1);

SDValue Buffer;
if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
Buffer =
DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
} else {
// Allocate space dynamically.
Buffer = DAG.getNode(
ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
{Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
MFI.CreateVariableSizedObject(Align(16), nullptr);
}

// Copy the value to a virtual register, and save that in FuncInfo.
Register BufferPtr =
MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
FuncInfo->setSMESaveBufferAddr(BufferPtr);
Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
}

if (CallConv == CallingConv::PreserveNone) {
@@ -8146,6 +8257,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
CallerAttrs.requiresLazySave(CalleeAttrs) ||
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
CallerAttrs.hasStreamingBody())
return false;

@@ -8559,6 +8671,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
};

bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
bool RequiresSaveAllZA =
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
SDValue ZAStateBuffer;
if (RequiresLazySave) {
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
MachinePointerInfo MPI =
@@ -8585,6 +8700,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
&MF.getFunction());
return DescribeCallsite(R) << " sets up a lazy save for ZA";
});
} else if (RequiresSaveAllZA) {
Chain =
DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
DAG.getConstant(Intrinsic::aarch64_sme_save, DL, MVT::i32),
DAG.getCopyFromReg(
Chain, DL, FuncInfo->getSMESaveBufferAddr(), MVT::i64));
}

SDValue PStateSM;
@@ -9138,9 +9259,17 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
DAG.getConstant(0, DL, MVT::i64));
TPIDR2.Uses++;
} else if (RequiresSaveAllZA) {
Result = DAG.getNode(
ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
DAG.getConstant(Intrinsic::aarch64_sme_restore, DL, MVT::i32),
DAG.getCopyFromReg(Result, DL, FuncInfo->getSMESaveBufferAddr(),
MVT::i64));
FuncInfo->setSMESaveBufferUsed();
}

if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
RequiresSaveAllZA) {
for (unsigned I = 0; I < InVals.size(); ++I) {
// The smstart/smstop is chained as part of the call, but when the
// resulting chain is discarded (which happens when the call is not part
@@ -27819,7 +27948,8 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
auto CalleeAttrs = SMEAttrs(*Base);
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
CallerAttrs.requiresLazySave(CalleeAttrs) ||
CallerAttrs.requiresPreservingZT0(CalleeAttrs))
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
return true;
}
return false;
6 changes: 6 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -460,6 +460,10 @@ enum NodeType : unsigned {
ALLOCATE_ZA_BUFFER,
INIT_TPIDR2OBJ,

// Needed for __arm_agnostic("sme_za_state")
GET_SME_SAVE_SIZE,
ALLOC_SME_SAVE_BUFFER,

// Asserts that a function argument (i32) is zero-extended to i8 by
// the caller
ASSERT_ZEXT_BOOL,
@@ -663,6 +667,8 @@ class AArch64TargetLowering : public TargetLowering {
MachineBasicBlock *BB) const;
MachineBasicBlock *EmitAllocateZABuffer(MachineInstr &MI,
MachineBasicBlock *BB) const;
MachineBasicBlock *EmitAllocateSMESaveBuffer(MachineInstr &MI,
MachineBasicBlock *BB) const;

MachineBasicBlock *
EmitInstrWithCustomInserter(MachineInstr &MI,
14 changes: 14 additions & 0 deletions llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -224,6 +224,14 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
// on function entry to record the initial pstate of a function.
Register PStateSMReg = MCRegister::NoRegister;

// Holds a pointer to a buffer that is large enough to represent
// all SME ZA state and any additional state required by the
// __arm_sme_save/restore support routines.
Register SMESaveBufferAddr = MCRegister::NoRegister;

// True if SMESaveBufferAddr is used.
bool SMESaveBufferUsed = false;

// Has the PNReg used to build PTRUE instruction.
// The PTRUE is used for the LD/ST of ZReg pairs in save and restore.
unsigned PredicateRegForFillSpill = 0;
@@ -247,6 +255,12 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
return PredicateRegForFillSpill;
}

Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };
void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; };

bool getSMESaveBufferUsed() const { return SMESaveBufferUsed; };
void setSMESaveBufferUsed(bool Used = true) { SMESaveBufferUsed = Used; };

Register getPStateSMReg() const { return PStateSMReg; };
void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };

16 changes: 16 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -52,6 +52,22 @@ let usesCustomInserter = 1 in {
def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer), [(AArch64InitTPIDR2Obj GPR64:$buffer)]>, Sched<[WriteI]> {}
}

// Nodes to allocate a save buffer for SME.
def AArch64SMESaveSize : SDNode<"AArch64ISD::GET_SME_SAVE_SIZE", SDTypeProfile<1, 0,
[SDTCisInt<0>]>, [SDNPHasChain]>;
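// The expansion of GetSMESaveSize may call __arm_sme_state_size, which
// returns the required buffer size in X0.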
let usesCustomInserter = 1, Defs = [X0] in {
def GetSMESaveSize : Pseudo<(outs GPR64:$dst), (ins), []>, Sched<[]> {}
}
def : Pat<(i64 AArch64SMESaveSize), (GetSMESaveSize)>;

def AArch64AllocateSMESaveBuffer : SDNode<"AArch64ISD::ALLOC_SME_SAVE_BUFFER", SDTypeProfile<1, 1,
[SDTCisInt<0>, SDTCisInt<1>]>, [SDNPHasChain]>;
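// The expansion of AllocateSMESaveBuffer grows the stack by moving SP down,
// hence SP is listed as a def.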
let usesCustomInserter = 1, Defs = [SP] in {
def AllocateSMESaveBuffer : Pseudo<(outs GPR64sp:$dst), (ins GPR64:$size), []>, Sched<[WriteI]> {}
}
def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)),
(AllocateSMESaveBuffer $size)>;

//===----------------------------------------------------------------------===//
// Instruction naming conventions.
//===----------------------------------------------------------------------===//
19 changes: 18 additions & 1 deletion llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -234,6 +234,17 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
(cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
isSMEABIRoutineCall(cast<CallInst>(I))))
return true;

if (auto *CB = dyn_cast<CallBase>(&I)) {
// Guard against indirect calls, which have no called function.
if (const Function *Callee = CB->getCalledFunction()) {
SMEAttrs CallerAttrs(*CB->getCaller()), CalleeAttrs(*Callee);
// When trying to determine if we can inline callees, we must check
// that an agnostic-ZA function doesn't call any function that is not
// agnostic-ZA, as such a call would require inserting save/restore
// code around it.
if (CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
return true;
}
}
}
}
return false;
@@ -255,7 +266,13 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,

if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
CallerAttrs.requiresSMChange(CalleeAttrs) ||
CallerAttrs.requiresPreservingZT0(CalleeAttrs)) {
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
if (hasPossibleIncompatibleOps(Callee))
return false;
}

if (CalleeAttrs.hasAgnosticZAInterface()) {
if (hasPossibleIncompatibleOps(Callee))
return false;
}
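A hedged source-level illustration of what hasPossibleIncompatibleOps now rejects; the `__arm_agnostic("sme_za_state")` spelling is taken from the ACLE proposal (ARM-software/acle#336) and should be treated as an assumption, as are the function names:

    extern "C" void private_za_fn(void); // hypothetical non-agnostic-ZA callee

    // Agnostic-ZA function whose body needs save/restore around the call, so
    // it must not be inlined into callers with a different ZA interface.
    extern "C" void bar(void) __arm_agnostic("sme_za_state") {
      private_za_fn(); // lowered with __arm_sme_save/__arm_sme_restore around it
    }

    void foo(void) {
      bar(); // the inliner must keep this as an outlined call
    }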
9 changes: 9 additions & 0 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -38,6 +38,10 @@ void SMEAttrs::set(unsigned M, bool Enable) {
isPreservesZT0())) &&
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
"'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");

assert(!(hasAgnosticZAInterface() && hasSharedZAInterface()) &&
"Function cannot have a shared-ZA interface and an agnostic-ZA "
"interface");
}

SMEAttrs::SMEAttrs(const CallBase &CB) {
@@ -56,6 +60,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
Bitmask |= SMEAttrs::SM_Compatible;
if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
FuncName == "__arm_sme_state_size")
Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
}

SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
@@ -66,6 +73,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask |= SM_Compatible;
if (Attrs.hasFnAttr("aarch64_pstate_sm_body"))
Bitmask |= SM_Body;
if (Attrs.hasFnAttr("aarch64_za_state_agnostic"))
Bitmask |= ZA_State_Agnostic;
if (Attrs.hasFnAttr("aarch64_in_za"))
Bitmask |= encodeZAState(StateValue::In);
if (Attrs.hasFnAttr("aarch64_out_za"))
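For reference, a hedged usage sketch of the new query that drives all of the above (the helper function is illustrative only; requiresPreservingAllZAState is the predicate this patch adds):

    #include "Utils/AArch64SMEAttributes.h" // AArch64-internal header, path as in this patch
    #include "llvm/IR/Function.h"

    using namespace llvm;

    // True when a call from Caller to Callee must be bracketed with
    // __arm_sme_save/__arm_sme_restore, as done in LowerCall above.
    static bool needsFullZASaveRestore(const Function &Caller,
                                       const Function &Callee) {
      SMEAttrs CallerAttrs(Caller), CalleeAttrs(Callee);
      return CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
    }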