summaryrefslogtreecommitdiff
path: root/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
diff options
context:
space:
mode:
authorArtem Belevich <tra@google.com>2017-02-23 22:38:24 +0000
committerArtem Belevich <tra@google.com>2017-02-23 22:38:24 +0000
commit6bc216ccf6a7fc8e9f500fb44b12f045995b4c3d (patch)
tree84463412d5bf8c030f398172005f3f7b7afd8205 /lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
parenta328146a758bf6d3f25429113bfee0a6575be284 (diff)
[NVPTX] Added support for .f16x2 instructions.
This patch enables support for .f16x2 operations. Added new register type Float16x2. Added support for .f16x2 instructions. Added handling of vectorized loads/stores of v2f16 values. Differential Revision: https://reviews.llvm.org/D30057 Differential Revision: https://reviews.llvm.org/D30310 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@296032 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp')
-rw-r--r--lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp346
1 files changed, 320 insertions, 26 deletions
diff --git a/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 2aef67b9caf..7da621ccdc3 100644
--- a/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -84,6 +84,14 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
if (tryStore(N))
return;
break;
+ case ISD::EXTRACT_VECTOR_ELT:
+ if (tryEXTRACT_VECTOR_ELEMENT(N))
+ return;
+ break;
+ case NVPTXISD::SETP_F16X2:
+ SelectSETP_F16X2(N);
+ return;
+
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
if (tryLoadVector(N))
@@ -516,6 +524,127 @@ bool NVPTXDAGToDAGISel::tryConstantFP16(SDNode *N) {
return true;
}
+// Map ISD:CONDCODE value to appropriate CmpMode expected by
+// NVPTXInstPrinter::printCmpMode()
+static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
+ using NVPTX::PTXCmpMode::CmpMode;
+ unsigned PTXCmpMode = [](ISD::CondCode CC) {
+ switch (CC) {
+ default:
+ llvm_unreachable("Unexpected condition code.");
+ case ISD::SETOEQ:
+ return CmpMode::EQ;
+ case ISD::SETOGT:
+ return CmpMode::GT;
+ case ISD::SETOGE:
+ return CmpMode::GE;
+ case ISD::SETOLT:
+ return CmpMode::LT;
+ case ISD::SETOLE:
+ return CmpMode::LE;
+ case ISD::SETONE:
+ return CmpMode::NE;
+ case ISD::SETO:
+ return CmpMode::NUM;
+ case ISD::SETUO:
+ return CmpMode::NotANumber;
+ case ISD::SETUEQ:
+ return CmpMode::EQU;
+ case ISD::SETUGT:
+ return CmpMode::GTU;
+ case ISD::SETUGE:
+ return CmpMode::GEU;
+ case ISD::SETULT:
+ return CmpMode::LTU;
+ case ISD::SETULE:
+ return CmpMode::LEU;
+ case ISD::SETUNE:
+ return CmpMode::NEU;
+ case ISD::SETEQ:
+ return CmpMode::EQ;
+ case ISD::SETGT:
+ return CmpMode::GT;
+ case ISD::SETGE:
+ return CmpMode::GE;
+ case ISD::SETLT:
+ return CmpMode::LT;
+ case ISD::SETLE:
+ return CmpMode::LE;
+ case ISD::SETNE:
+ return CmpMode::NE;
+ }
+ }(CondCode.get());
+
+ if (FTZ)
+ PTXCmpMode |= NVPTX::PTXCmpMode::FTZ_FLAG;
+
+ return PTXCmpMode;
+}
+
+bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
+ unsigned PTXCmpMode =
+ getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
+ SDLoc DL(N);
+ SDNode *SetP = CurDAG->getMachineNode(
+ NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
+ N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
+ ReplaceNode(N, SetP);
+ return true;
+}
+
+// Find all instances of extract_vector_elt that use this v2f16 vector
+// and coalesce them into a scattering move instruction.
+bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
+ SDValue Vector = N->getOperand(0);
+
+ // We only care about f16x2 as it's the only real vector type we
+ // need to deal with.
+ if (Vector.getSimpleValueType() != MVT::v2f16)
+ return false;
+
+ // Find and record all uses of this vector that extract element 0 or 1.
+ SmallVector<SDNode *, 4> E0, E1;
+ for (const auto &U : Vector.getNode()->uses()) {
+ if (U->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+ continue;
+ if (U->getOperand(0) != Vector)
+ continue;
+ if (const ConstantSDNode *IdxConst =
+ dyn_cast<ConstantSDNode>(U->getOperand(1))) {
+ if (IdxConst->getZExtValue() == 0)
+ E0.push_back(U);
+ else if (IdxConst->getZExtValue() == 1)
+ E1.push_back(U);
+ else
+ llvm_unreachable("Invalid vector index.");
+ }
+ }
+
+ // There's no point scattering f16x2 if we only ever access one
+ // element of it.
+ if (E0.empty() || E1.empty())
+ return false;
+
+ unsigned Op = NVPTX::SplitF16x2;
+ // If the vector has been BITCAST'ed from i32, we can use original
+ // value directly and avoid register-to-register move.
+ SDValue Source = Vector;
+ if (Vector->getOpcode() == ISD::BITCAST) {
+ Op = NVPTX::SplitI32toF16x2;
+ Source = Vector->getOperand(0);
+ }
+ // Merge (f16 extractelt(V, 0), f16 extractelt(V,1))
+ // into f16,f16 SplitF16x2(V)
+ SDNode *ScatterOp =
+ CurDAG->getMachineNode(Op, SDLoc(N), MVT::f16, MVT::f16, Source);
+ for (auto *Node : E0)
+ ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
+ for (auto *Node : E1)
+ ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 1));
+
+ return true;
+}
+
static unsigned int getCodeAddrSpace(MemSDNode *N) {
const Value *Src = N->getMemOperand()->getValue();
@@ -689,29 +818,26 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
codeAddrSpace != NVPTX::PTXLdStInstCode::GENERIC)
isVolatile = false;
- // Vector Setting
- MVT SimpleVT = LoadedVT.getSimpleVT();
- unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
- if (SimpleVT.isVector()) {
- unsigned num = SimpleVT.getVectorNumElements();
- if (num == 2)
- vecType = NVPTX::PTXLdStInstCode::V2;
- else if (num == 4)
- vecType = NVPTX::PTXLdStInstCode::V4;
- else
- return false;
- }
-
// Type Setting: fromType + fromTypeWidth
//
// Sign : ISD::SEXTLOAD
// Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
// type is integer
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
+ MVT SimpleVT = LoadedVT.getSimpleVT();
MVT ScalarVT = SimpleVT.getScalarType();
// Read at least 8 bits (predicates are stored as 8-bit values)
unsigned fromTypeWidth = std::max(8U, ScalarVT.getSizeInBits());
unsigned int fromType;
+
+ // Vector Setting
+ unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
+ if (SimpleVT.isVector()) {
+ assert(LoadedVT == MVT::v2f16 && "Unexpected vector type");
+ // v2f16 is loaded using ld.b32
+ fromTypeWidth = 32;
+ }
+
if ((LD->getExtensionType() == ISD::SEXTLOAD))
fromType = NVPTX::PTXLdStInstCode::Signed;
else if (ScalarVT.isFloatingPoint())
@@ -746,6 +872,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::f16:
Opcode = NVPTX::LD_f16_avar;
break;
+ case MVT::v2f16:
+ Opcode = NVPTX::LD_f16x2_avar;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_avar;
break;
@@ -777,6 +906,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::f16:
Opcode = NVPTX::LD_f16_asi;
break;
+ case MVT::v2f16:
+ Opcode = NVPTX::LD_f16x2_asi;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_asi;
break;
@@ -809,6 +941,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::f16:
Opcode = NVPTX::LD_f16_ari_64;
break;
+ case MVT::v2f16:
+ Opcode = NVPTX::LD_f16x2_ari_64;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_ari_64;
break;
@@ -835,6 +970,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::f16:
Opcode = NVPTX::LD_f16_ari;
break;
+ case MVT::v2f16:
+ Opcode = NVPTX::LD_f16x2_ari;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_ari;
break;
@@ -867,6 +1005,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::f16:
Opcode = NVPTX::LD_f16_areg_64;
break;
+ case MVT::v2f16:
+ Opcode = NVPTX::LD_f16x2_areg_64;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_areg_64;
break;
@@ -893,6 +1034,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case MVT::f16:
Opcode = NVPTX::LD_f16_areg;
break;
+ case MVT::v2f16:
+ Opcode = NVPTX::LD_f16x2_areg;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_areg;
break;
@@ -968,7 +1112,8 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
if (ExtensionType == ISD::SEXTLOAD)
FromType = NVPTX::PTXLdStInstCode::Signed;
else if (ScalarVT.isFloatingPoint())
- FromType = NVPTX::PTXLdStInstCode::Float;
+ FromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
+ : NVPTX::PTXLdStInstCode::Float;
else
FromType = NVPTX::PTXLdStInstCode::Unsigned;
@@ -987,6 +1132,16 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
EVT EltVT = N->getValueType(0);
+ // v8f16 is a special case. PTX doesn't have ld.v8.f16
+ // instruction. Instead, we split the vector into v2f16 chunks and
+ // load them with ld.v4.b32.
+ if (EltVT == MVT::v2f16) {
+ assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode.");
+ EltVT = MVT::i32;
+ FromType = NVPTX::PTXLdStInstCode::Untyped;
+ FromTypeWidth = 32;
+ }
+
if (SelectDirectAddr(Op1, Addr)) {
switch (N->getOpcode()) {
default:
@@ -1007,6 +1162,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LDV_i64_v2_avar;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LDV_f16_v2_avar;
+ break;
case MVT::f32:
Opcode = NVPTX::LDV_f32_v2_avar;
break;
@@ -1028,6 +1186,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
case MVT::i32:
Opcode = NVPTX::LDV_i32_v4_avar;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LDV_f16_v4_avar;
+ break;
case MVT::f32:
Opcode = NVPTX::LDV_f32_v4_avar;
break;
@@ -1060,6 +1221,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LDV_i64_v2_asi;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LDV_f16_v2_asi;
+ break;
case MVT::f32:
Opcode = NVPTX::LDV_f32_v2_asi;
break;
@@ -1081,6 +1245,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
case MVT::i32:
Opcode = NVPTX::LDV_i32_v4_asi;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LDV_f16_v4_asi;
+ break;
case MVT::f32:
Opcode = NVPTX::LDV_f32_v4_asi;
break;
@@ -1114,6 +1281,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LDV_i64_v2_ari_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LDV_f16_v2_ari_64;
+ break;
case MVT::f32:
Opcode = NVPTX::LDV_f32_v2_ari_64;
break;
@@ -1135,6 +1305,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
case MVT::i32:
Opcode = NVPTX::LDV_i32_v4_ari_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LDV_f16_v4_ari_64;
+ break;
case MVT::f32:
Opcode = NVPTX::LDV_f32_v4_ari_64;
break;
@@ -1161,6 +1334,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LDV_i64_v2_ari;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LDV_f16_v2_ari;
+ break;
case MVT::f32:
Opcode = NVPTX::LDV_f32_v2_ari;
break;
@@ -1182,6 +1358,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
case MVT::i32:
Opcode = NVPTX::LDV_i32_v4_ari;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LDV_f16_v4_ari;
+ break;
case MVT::f32:
Opcode = NVPTX::LDV_f32_v4_ari;
break;
@@ -1216,6 +1395,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LDV_i64_v2_areg_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LDV_f16_v2_areg_64;
+ break;
case MVT::f32:
Opcode = NVPTX::LDV_f32_v2_areg_64;
break;
@@ -1237,6 +1419,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
case MVT::i32:
Opcode = NVPTX::LDV_i32_v4_areg_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LDV_f16_v4_areg_64;
+ break;
case MVT::f32:
Opcode = NVPTX::LDV_f32_v4_areg_64;
break;
@@ -1263,6 +1448,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::LDV_i64_v2_areg;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LDV_f16_v2_areg;
+ break;
case MVT::f32:
Opcode = NVPTX::LDV_f32_v2_areg;
break;
@@ -1284,6 +1472,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
case MVT::i32:
Opcode = NVPTX::LDV_i32_v4_areg;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LDV_f16_v4_areg;
+ break;
case MVT::f32:
Opcode = NVPTX::LDV_f32_v4_areg;
break;
@@ -2151,21 +2342,18 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
// Vector Setting
MVT SimpleVT = StoreVT.getSimpleVT();
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
- if (SimpleVT.isVector()) {
- unsigned num = SimpleVT.getVectorNumElements();
- if (num == 2)
- vecType = NVPTX::PTXLdStInstCode::V2;
- else if (num == 4)
- vecType = NVPTX::PTXLdStInstCode::V4;
- else
- return false;
- }
// Type Setting: toType + toTypeWidth
// - for integer type, always use 'u'
//
MVT ScalarVT = SimpleVT.getScalarType();
unsigned toTypeWidth = ScalarVT.getSizeInBits();
+ if (SimpleVT.isVector()) {
+ assert(StoreVT == MVT::v2f16 && "Unexpected vector type");
+ // v2f16 is stored using st.b32
+ toTypeWidth = 32;
+ }
+
unsigned int toType;
if (ScalarVT.isFloatingPoint())
// f16 uses .b16 as its storage type.
@@ -2200,6 +2388,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::f16:
Opcode = NVPTX::ST_f16_avar;
break;
+ case MVT::v2f16:
+ Opcode = NVPTX::ST_f16x2_avar;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_avar;
break;
@@ -2232,6 +2423,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::f16:
Opcode = NVPTX::ST_f16_asi;
break;
+ case MVT::v2f16:
+ Opcode = NVPTX::ST_f16x2_asi;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_asi;
break;
@@ -2265,6 +2459,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::f16:
Opcode = NVPTX::ST_f16_ari_64;
break;
+ case MVT::v2f16:
+ Opcode = NVPTX::ST_f16x2_ari_64;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_ari_64;
break;
@@ -2291,6 +2488,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::f16:
Opcode = NVPTX::ST_f16_ari;
break;
+ case MVT::v2f16:
+ Opcode = NVPTX::ST_f16x2_ari;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_ari;
break;
@@ -2324,6 +2524,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::f16:
Opcode = NVPTX::ST_f16_areg_64;
break;
+ case MVT::v2f16:
+ Opcode = NVPTX::ST_f16x2_areg_64;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_areg_64;
break;
@@ -2350,6 +2553,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
case MVT::f16:
Opcode = NVPTX::ST_f16_areg;
break;
+ case MVT::v2f16:
+ Opcode = NVPTX::ST_f16x2_areg;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_areg;
break;
@@ -2411,7 +2617,8 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
unsigned ToType;
if (ScalarVT.isFloatingPoint())
- ToType = NVPTX::PTXLdStInstCode::Float;
+ ToType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
+ : NVPTX::PTXLdStInstCode::Float;
else
ToType = NVPTX::PTXLdStInstCode::Unsigned;
@@ -2438,6 +2645,16 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
return false;
}
+ // v8f16 is a special case. PTX doesn't have st.v8.f16
+ // instruction. Instead, we split the vector into v2f16 chunks and
+ // store them with st.v4.b32.
+ if (EltVT == MVT::v2f16) {
+ assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected load opcode.");
+ EltVT = MVT::i32;
+ ToType = NVPTX::PTXLdStInstCode::Untyped;
+ ToTypeWidth = 32;
+ }
+
StOps.push_back(getI32Imm(IsVolatile, DL));
StOps.push_back(getI32Imm(CodeAddrSpace, DL));
StOps.push_back(getI32Imm(VecType, DL));
@@ -2464,6 +2681,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::STV_i64_v2_avar;
break;
+ case MVT::f16:
+ Opcode = NVPTX::STV_f16_v2_avar;
+ break;
case MVT::f32:
Opcode = NVPTX::STV_f32_v2_avar;
break;
@@ -2513,6 +2733,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::STV_i64_v2_asi;
break;
+ case MVT::f16:
+ Opcode = NVPTX::STV_f16_v2_asi;
+ break;
case MVT::f32:
Opcode = NVPTX::STV_f32_v2_asi;
break;
@@ -2534,6 +2757,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
case MVT::i32:
Opcode = NVPTX::STV_i32_v4_asi;
break;
+ case MVT::f16:
+ Opcode = NVPTX::STV_f16_v4_asi;
+ break;
case MVT::f32:
Opcode = NVPTX::STV_f32_v4_asi;
break;
@@ -2564,6 +2790,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::STV_i64_v2_ari_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::STV_f16_v2_ari_64;
+ break;
case MVT::f32:
Opcode = NVPTX::STV_f32_v2_ari_64;
break;
@@ -2585,6 +2814,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
case MVT::i32:
Opcode = NVPTX::STV_i32_v4_ari_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::STV_f16_v4_ari_64;
+ break;
case MVT::f32:
Opcode = NVPTX::STV_f32_v4_ari_64;
break;
@@ -2611,6 +2843,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::STV_i64_v2_ari;
break;
+ case MVT::f16:
+ Opcode = NVPTX::STV_f16_v2_ari;
+ break;
case MVT::f32:
Opcode = NVPTX::STV_f32_v2_ari;
break;
@@ -2632,6 +2867,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
case MVT::i32:
Opcode = NVPTX::STV_i32_v4_ari;
break;
+ case MVT::f16:
+ Opcode = NVPTX::STV_f16_v4_ari;
+ break;
case MVT::f32:
Opcode = NVPTX::STV_f32_v4_ari;
break;
@@ -2662,6 +2900,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::STV_i64_v2_areg_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::STV_f16_v2_areg_64;
+ break;
case MVT::f32:
Opcode = NVPTX::STV_f32_v2_areg_64;
break;
@@ -2683,6 +2924,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
case MVT::i32:
Opcode = NVPTX::STV_i32_v4_areg_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::STV_f16_v4_areg_64;
+ break;
case MVT::f32:
Opcode = NVPTX::STV_f32_v4_areg_64;
break;
@@ -2709,6 +2953,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::STV_i64_v2_areg;
break;
+ case MVT::f16:
+ Opcode = NVPTX::STV_f16_v2_areg;
+ break;
case MVT::f32:
Opcode = NVPTX::STV_f32_v2_areg;
break;
@@ -2730,6 +2977,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
case MVT::i32:
Opcode = NVPTX::STV_i32_v4_areg;
break;
+ case MVT::f16:
+ Opcode = NVPTX::STV_f16_v4_areg;
+ break;
case MVT::f32:
Opcode = NVPTX::STV_f32_v4_areg;
break;
@@ -2804,6 +3054,9 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
case MVT::f16:
Opc = NVPTX::LoadParamMemF16;
break;
+ case MVT::v2f16:
+ Opc = NVPTX::LoadParamMemF16x2;
+ break;
case MVT::f32:
Opc = NVPTX::LoadParamMemF32;
break;
@@ -2831,6 +3084,12 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
case MVT::i64:
Opc = NVPTX::LoadParamMemV2I64;
break;
+ case MVT::f16:
+ Opc = NVPTX::LoadParamMemV2F16;
+ break;
+ case MVT::v2f16:
+ Opc = NVPTX::LoadParamMemV2F16x2;
+ break;
case MVT::f32:
Opc = NVPTX::LoadParamMemV2F32;
break;
@@ -2855,6 +3114,12 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
case MVT::i32:
Opc = NVPTX::LoadParamMemV4I32;
break;
+ case MVT::f16:
+ Opc = NVPTX::LoadParamMemV4F16;
+ break;
+ case MVT::v2f16:
+ Opc = NVPTX::LoadParamMemV4F16x2;
+ break;
case MVT::f32:
Opc = NVPTX::LoadParamMemV4F32;
break;
@@ -2942,6 +3207,9 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
case MVT::f16:
Opcode = NVPTX::StoreRetvalF16;
break;
+ case MVT::v2f16:
+ Opcode = NVPTX::StoreRetvalF16x2;
+ break;
case MVT::f32:
Opcode = NVPTX::StoreRetvalF32;
break;
@@ -2969,6 +3237,12 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::StoreRetvalV2I64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::StoreRetvalV2F16;
+ break;
+ case MVT::v2f16:
+ Opcode = NVPTX::StoreRetvalV2F16x2;
+ break;
case MVT::f32:
Opcode = NVPTX::StoreRetvalV2F32;
break;
@@ -2993,6 +3267,12 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
case MVT::i32:
Opcode = NVPTX::StoreRetvalV4I32;
break;
+ case MVT::f16:
+ Opcode = NVPTX::StoreRetvalV4F16;
+ break;
+ case MVT::v2f16:
+ Opcode = NVPTX::StoreRetvalV4F16x2;
+ break;
case MVT::f32:
Opcode = NVPTX::StoreRetvalV4F32;
break;
@@ -3000,8 +3280,7 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
break;
}
- SDNode *Ret =
- CurDAG->getMachineNode(Opcode, DL, MVT::Other, Ops);
+ SDNode *Ret = CurDAG->getMachineNode(Opcode, DL, MVT::Other, Ops);
MachineSDNode::mmo_iterator MemRefs0 = MF->allocateMemRefsArray(1);
MemRefs0[0] = cast<MemSDNode>(N)->getMemOperand();
cast<MachineSDNode>(Ret)->setMemRefs(MemRefs0, MemRefs0 + 1);
@@ -3078,6 +3357,9 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
case MVT::f16:
Opcode = NVPTX::StoreParamF16;
break;
+ case MVT::v2f16:
+ Opcode = NVPTX::StoreParamF16x2;
+ break;
case MVT::f32:
Opcode = NVPTX::StoreParamF32;
break;
@@ -3105,6 +3387,12 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
case MVT::i64:
Opcode = NVPTX::StoreParamV2I64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::StoreParamV2F16;
+ break;
+ case MVT::v2f16:
+ Opcode = NVPTX::StoreParamV2F16x2;
+ break;
case MVT::f32:
Opcode = NVPTX::StoreParamV2F32;
break;
@@ -3129,6 +3417,12 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
case MVT::i32:
Opcode = NVPTX::StoreParamV4I32;
break;
+ case MVT::f16:
+ Opcode = NVPTX::StoreParamV4F16;
+ break;
+ case MVT::v2f16:
+ Opcode = NVPTX::StoreParamV4F16x2;
+ break;
case MVT::f32:
Opcode = NVPTX::StoreParamV4F32;
break;