@@ -30,7 +30,6 @@
 #include "SIInstrInfo.h"
 #include "SIMachineFunctionInfo.h"
 #include "MCTargetDesc/AMDGPUMCTargetDesc.h"
-#include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
@@ -41,6 +40,18 @@
 #include "llvm/Support/KnownBits.h"
 using namespace llvm;
 
+static bool allocateKernArg(unsigned ValNo, MVT ValVT, MVT LocVT,
+                            CCValAssign::LocInfo LocInfo,
+                            ISD::ArgFlagsTy ArgFlags, CCState &State) {
+  MachineFunction &MF = State.getMachineFunction();
+  AMDGPUMachineFunction *MFI = MF.getInfo<AMDGPUMachineFunction>();
+
+  uint64_t Offset = MFI->allocateKernArg(LocVT.getStoreSize(),
+                                         ArgFlags.getOrigAlign());
+  State.addLoc(CCValAssign::getCustomMem(ValNo, ValVT, Offset, LocVT, LocInfo));
+  return true;
+}
+
 static bool allocateCCRegs(unsigned ValNo, MVT ValVT, MVT LocVT,
                            CCValAssign::LocInfo LocInfo,
                            ISD::ArgFlagsTy ArgFlags, CCState &State,
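
For context on the restored helper: MFI->allocateKernArg(Size, Align) is expected to round a running kernel-argument offset up to the requested alignment, return that offset, and then advance the cursor past the argument's store size. The standalone sketch below models that bookkeeping in isolation; the struct and its KernArgSize/MaxKernArgAlign fields are illustrative assumptions, not the verbatim AMDGPUMachineFunction implementation.

#include <algorithm>
#include <cstdint>

// Hypothetical stand-in for the offset bookkeeping that the diff's
// allocateKernArg() helper delegates to AMDGPUMachineFunction; the names
// KernArgSize and MaxKernArgAlign are assumptions for illustration only.
struct KernArgAllocatorSketch {
  uint64_t KernArgSize = 0;     // running byte offset for kernel arguments
  unsigned MaxKernArgAlign = 1; // largest alignment requested so far

  static uint64_t alignTo(uint64_t Value, unsigned Align) {
    return (Value + Align - 1) / Align * Align; // round up to a multiple
  }

  // Align the cursor, return this argument's offset, then advance past it.
  uint64_t allocateKernArg(uint64_t Size, unsigned Align) {
    KernArgSize = alignTo(KernArgSize, Align);
    uint64_t Offset = KernArgSize;
    KernArgSize += Size;
    MaxKernArgAlign = std::max(Align, MaxKernArgAlign);
    return Offset;
  }
};

// Example: kernel(char, int, double) lands at offsets 0, 4, and 8.
int main() {
  KernArgAllocatorSketch A;
  return A.allocateKernArg(1, 1) == 0 && // char at 0
         A.allocateKernArg(4, 4) == 4 && // int padded up to 4
         A.allocateKernArg(8, 8) == 8    // double at 8
             ? 0 : 1;
}
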
@@ -899,118 +910,74 @@ CCAssignFn *AMDGPUCallLowering::CCAssignFnForReturn(CallingConv::ID CC,
 /// for each individual part is i8. We pass the memory type as LocVT to the
 /// calling convention analysis function and the register type (Ins[x].VT) as
 /// the ValVT.
-void AMDGPUTargetLowering::analyzeFormalArgumentsCompute(
-    CCState &State,
-    const SmallVectorImpl<ISD::InputArg> &Ins) const {
-  const MachineFunction &MF = State.getMachineFunction();
-  const Function &Fn = MF.getFunction();
-  LLVMContext &Ctx = Fn.getParent()->getContext();
-  const AMDGPUSubtarget &ST = AMDGPUSubtarget::get(MF);
-  const unsigned ExplicitOffset = ST.getExplicitKernelArgOffset(Fn);
-
-  unsigned MaxAlign = 1;
-  uint64_t ExplicitArgOffset = 0;
-  const DataLayout &DL = Fn.getParent()->getDataLayout();
-
-  unsigned InIndex = 0;
-
-  for (const Argument &Arg : Fn.args()) {
-    Type *BaseArgTy = Arg.getType();
-    unsigned Align = DL.getABITypeAlignment(BaseArgTy);
-    MaxAlign = std::max(Align, MaxAlign);
-    unsigned AllocSize = DL.getTypeAllocSize(BaseArgTy);
-
-    uint64_t ArgOffset = alignTo(ExplicitArgOffset, Align) + ExplicitOffset;
-    ExplicitArgOffset = alignTo(ExplicitArgOffset, Align) + AllocSize;
-
-    // We're basically throwing away everything passed into us and starting over
-    // to get accurate in-memory offsets. The "PartOffset" is completely useless
-    // to us as computed in Ins.
-    //
-    // We also need to figure out what type legalization is trying to do to get
-    // the correct memory offsets.
-
-    SmallVector<EVT, 16> ValueVTs;
-    SmallVector<uint64_t, 16> Offsets;
-    ComputeValueVTs(*this, DL, BaseArgTy, ValueVTs, &Offsets, ArgOffset);
-
-    for (unsigned Value = 0, NumValues = ValueVTs.size();
-         Value != NumValues; ++Value) {
-      uint64_t BasePartOffset = Offsets[Value];
-
-      EVT ArgVT = ValueVTs[Value];
-      EVT MemVT = ArgVT;
-      MVT RegisterVT =
-        getRegisterTypeForCallingConv(Ctx, ArgVT);
-      unsigned NumRegs =
-        getNumRegistersForCallingConv(Ctx, ArgVT);
-
-      if (!Subtarget->isAmdHsaOS() &&
-          (ArgVT == MVT::i16 || ArgVT == MVT::i8 || ArgVT == MVT::f16)) {
-        // The ABI says the caller will extend these values to 32-bits.
-        MemVT = ArgVT.isInteger() ? MVT::i32 : MVT::f32;
-      } else if (NumRegs == 1) {
-        // This argument is not split, so the IR type is the memory type.
-        if (ArgVT.isExtended()) {
-          // We have an extended type, like i24, so we should just use the
-          // register type.
-          MemVT = RegisterVT;
-        } else {
-          MemVT = ArgVT;
-        }
-      } else if (ArgVT.isVector() && RegisterVT.isVector() &&
-                 ArgVT.getScalarType() == RegisterVT.getScalarType()) {
-        assert(ArgVT.getVectorNumElements() > RegisterVT.getVectorNumElements());
-        // We have a vector value which has been split into a vector with
-        // the same scalar type, but fewer elements. This should handle
-        // all the floating-point vector types.
-        MemVT = RegisterVT;
-      } else if (ArgVT.isVector() &&
-                 ArgVT.getVectorNumElements() == NumRegs) {
-        // This arg has been split so that each element is stored in a separate
-        // register.
-        MemVT = ArgVT.getScalarType();
-      } else if (ArgVT.isExtended()) {
-        // We have an extended type, like i65.
-        MemVT = RegisterVT;
+void AMDGPUTargetLowering::analyzeFormalArgumentsCompute(CCState &State,
+                             const SmallVectorImpl<ISD::InputArg> &Ins) const {
+  for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
+    const ISD::InputArg &In = Ins[i];
+    EVT MemVT;
+
+    unsigned NumRegs = getNumRegisters(State.getContext(), In.ArgVT);
+
+    if (!Subtarget->isAmdHsaOS() &&
+        (In.ArgVT == MVT::i16 || In.ArgVT == MVT::i8 || In.ArgVT == MVT::f16)) {
+      // The ABI says the caller will extend these values to 32-bits.
+      MemVT = In.ArgVT.isInteger() ? MVT::i32 : MVT::f32;
+    } else if (NumRegs == 1) {
+      // This argument is not split, so the IR type is the memory type.
+      assert(!In.Flags.isSplit());
+      if (In.ArgVT.isExtended()) {
+        // We have an extended type, like i24, so we should just use the register type
+        MemVT = In.VT;
       } else {
-        unsigned MemoryBits = ArgVT.getStoreSizeInBits() / NumRegs;
-        assert(ArgVT.getStoreSizeInBits() % NumRegs == 0);
-        if (RegisterVT.isInteger()) {
-          MemVT = EVT::getIntegerVT(State.getContext(), MemoryBits);
-        } else if (RegisterVT.isVector()) {
-          assert(!RegisterVT.getScalarType().isFloatingPoint());
-          unsigned NumElements = RegisterVT.getVectorNumElements();
-          assert(MemoryBits % NumElements == 0);
-          // This vector type has been split into another vector type with
-          // a different elements size.
-          EVT ScalarVT = EVT::getIntegerVT(State.getContext(),
-                                           MemoryBits / NumElements);
-          MemVT = EVT::getVectorVT(State.getContext(), ScalarVT, NumElements);
-        } else {
-          llvm_unreachable("cannot deduce memory type.");
-        }
+        MemVT = In.ArgVT;
       }
-
-      // Convert one element vectors to scalar.
-      if (MemVT.isVector() && MemVT.getVectorNumElements() == 1)
-        MemVT = MemVT.getScalarType();
-
-      if (MemVT.isExtended()) {
-        // This should really only happen if we have vec3 arguments
-        assert(MemVT.isVector() && MemVT.getVectorNumElements() == 3);
-        MemVT = MemVT.getPow2VectorType(State.getContext());
+    } else if (In.ArgVT.isVector() && In.VT.isVector() &&
+               In.ArgVT.getScalarType() == In.VT.getScalarType()) {
+      assert(In.ArgVT.getVectorNumElements() > In.VT.getVectorNumElements());
+      // We have a vector value which has been split into a vector with
+      // the same scalar type, but fewer elements. This should handle
+      // all the floating-point vector types.
+      MemVT = In.VT;
+    } else if (In.ArgVT.isVector() &&
+               In.ArgVT.getVectorNumElements() == NumRegs) {
+      // This arg has been split so that each element is stored in a separate
+      // register.
+      MemVT = In.ArgVT.getScalarType();
+    } else if (In.ArgVT.isExtended()) {
+      // We have an extended type, like i65.
+      MemVT = In.VT;
+    } else {
+      unsigned MemoryBits = In.ArgVT.getStoreSizeInBits() / NumRegs;
+      assert(In.ArgVT.getStoreSizeInBits() % NumRegs == 0);
+      if (In.VT.isInteger()) {
+        MemVT = EVT::getIntegerVT(State.getContext(), MemoryBits);
+      } else if (In.VT.isVector()) {
+        assert(!In.VT.getScalarType().isFloatingPoint());
+        unsigned NumElements = In.VT.getVectorNumElements();
+        assert(MemoryBits % NumElements == 0);
+        // This vector type has been split into another vector type with
+        // a different elements size.
+        EVT ScalarVT = EVT::getIntegerVT(State.getContext(),
+                                         MemoryBits / NumElements);
+        MemVT = EVT::getVectorVT(State.getContext(), ScalarVT, NumElements);
+      } else {
+        llvm_unreachable("cannot deduce memory type.");
       }
+    }
 
-      unsigned PartOffset = 0;
-      for (unsigned i = 0; i != NumRegs; ++i) {
-        State.addLoc(CCValAssign::getCustomMem(InIndex++, RegisterVT,
-                                               BasePartOffset + PartOffset,
-                                               MemVT.getSimpleVT(),
-                                               CCValAssign::Full));
-        PartOffset += MemVT.getStoreSize();
-      }
+    // Convert one element vectors to scalar.
+    if (MemVT.isVector() && MemVT.getVectorNumElements() == 1)
+      MemVT = MemVT.getScalarType();
+
+    if (MemVT.isExtended()) {
+      // This should really only happen if we have vec3 arguments
+      assert(MemVT.isVector() && MemVT.getVectorNumElements() == 3);
+      MemVT = MemVT.getPow2VectorType(State.getContext());
     }
+
+    assert(MemVT.isSimple());
+    allocateKernArg(i, In.VT, MemVT.getSimpleVT(), CCValAssign::Full, In.Flags,
+                    State);
   }
 }
 
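
Taken together, the restored loop derives a memory type (MemVT) for each argument from the pre-legalization type In.ArgVT and the legalized register type In.VT, normalizes one-element vectors to scalars, widens vec3 to the next power of two, and then records the location through allocateKernArg, passing MemVT as the LocVT and In.VT as the ValVT exactly as the doc comment describes. Below is a minimal sketch of the two normalization steps at the end of the loop, using hypothetical element counts rather than values taken from the patch.

#include <cassert>

// Hypothetical model of the final MemVT normalization in the restored loop:
// one-element vectors collapse to scalars, and a 3-element vector (which has
// no simple MVT in this context) widens to the next power-of-two count.
static unsigned normalizeElementCount(unsigned NumElements) {
  if (NumElements == 1)
    return 1; // "Convert one element vectors to scalar."
  unsigned Pow2 = 1;
  while (Pow2 < NumElements)
    Pow2 <<= 1; // round up: vec3 -> vec4
  return Pow2;
}

int main() {
  assert(normalizeElementCount(1) == 1); // v1i32 is treated as plain i32
  assert(normalizeElementCount(3) == 4); // v3i32 is laid out like v4i32
  assert(normalizeElementCount(4) == 4); // power-of-two vectors unchanged
  return 0;
}

Because one-element vectors are scalarized first, only the vec3 case should ever reach the power-of-two widening, which is what the assert before getPow2VectorType encodes.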