diff options
author | Artem Belevich <tra@google.com> | 2017-02-23 22:38:24 +0000 |
---|---|---|
committer | Artem Belevich <tra@google.com> | 2017-02-23 22:38:24 +0000 |
commit | 6bc216ccf6a7fc8e9f500fb44b12f045995b4c3d (patch) | |
tree | 84463412d5bf8c030f398172005f3f7b7afd8205 /lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | |
parent | a328146a758bf6d3f25429113bfee0a6575be284 (diff) |
[NVPTX] Added support for .f16x2 instructions.
This patch enables support for .f16x2 operations.
Added new register type Float16x2.
Added support for .f16x2 instructions.
Added handling of vectorized loads/stores of v2f16 values.
Differential Revision: https://reviews.llvm.org/D30057
Differential Revision: https://reviews.llvm.org/D30310
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@296032 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp')
-rw-r--r-- | lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 346 |
1 files changed, 320 insertions, 26 deletions
diff --git a/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index 2aef67b9caf..7da621ccdc3 100644 --- a/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -84,6 +84,14 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) { if (tryStore(N)) return; break; + case ISD::EXTRACT_VECTOR_ELT: + if (tryEXTRACT_VECTOR_ELEMENT(N)) + return; + break; + case NVPTXISD::SETP_F16X2: + SelectSETP_F16X2(N); + return; + case NVPTXISD::LoadV2: case NVPTXISD::LoadV4: if (tryLoadVector(N)) @@ -516,6 +524,127 @@ bool NVPTXDAGToDAGISel::tryConstantFP16(SDNode *N) { return true; } +// Map ISD:CONDCODE value to appropriate CmpMode expected by +// NVPTXInstPrinter::printCmpMode() +static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) { + using NVPTX::PTXCmpMode::CmpMode; + unsigned PTXCmpMode = [](ISD::CondCode CC) { + switch (CC) { + default: + llvm_unreachable("Unexpected condition code."); + case ISD::SETOEQ: + return CmpMode::EQ; + case ISD::SETOGT: + return CmpMode::GT; + case ISD::SETOGE: + return CmpMode::GE; + case ISD::SETOLT: + return CmpMode::LT; + case ISD::SETOLE: + return CmpMode::LE; + case ISD::SETONE: + return CmpMode::NE; + case ISD::SETO: + return CmpMode::NUM; + case ISD::SETUO: + return CmpMode::NotANumber; + case ISD::SETUEQ: + return CmpMode::EQU; + case ISD::SETUGT: + return CmpMode::GTU; + case ISD::SETUGE: + return CmpMode::GEU; + case ISD::SETULT: + return CmpMode::LTU; + case ISD::SETULE: + return CmpMode::LEU; + case ISD::SETUNE: + return CmpMode::NEU; + case ISD::SETEQ: + return CmpMode::EQ; + case ISD::SETGT: + return CmpMode::GT; + case ISD::SETGE: + return CmpMode::GE; + case ISD::SETLT: + return CmpMode::LT; + case ISD::SETLE: + return CmpMode::LE; + case ISD::SETNE: + return CmpMode::NE; + } + }(CondCode.get()); + + if (FTZ) + PTXCmpMode |= NVPTX::PTXCmpMode::FTZ_FLAG; + + return PTXCmpMode; +} + +bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) { + unsigned PTXCmpMode = + getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ()); + SDLoc DL(N); + SDNode *SetP = CurDAG->getMachineNode( + NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0), + N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32)); + ReplaceNode(N, SetP); + return true; +} + +// Find all instances of extract_vector_elt that use this v2f16 vector +// and coalesce them into a scattering move instruction. +bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) { + SDValue Vector = N->getOperand(0); + + // We only care about f16x2 as it's the only real vector type we + // need to deal with. + if (Vector.getSimpleValueType() != MVT::v2f16) + return false; + + // Find and record all uses of this vector that extract element 0 or 1. + SmallVector<SDNode *, 4> E0, E1; + for (const auto &U : Vector.getNode()->uses()) { + if (U->getOpcode() != ISD::EXTRACT_VECTOR_ELT) + continue; + if (U->getOperand(0) != Vector) + continue; + if (const ConstantSDNode *IdxConst = + dyn_cast<ConstantSDNode>(U->getOperand(1))) { + if (IdxConst->getZExtValue() == 0) + E0.push_back(U); + else if (IdxConst->getZExtValue() == 1) + E1.push_back(U); + else + llvm_unreachable("Invalid vector index."); + } + } + + // There's no point scattering f16x2 if we only ever access one + // element of it. + if (E0.empty() || E1.empty()) + return false; + + unsigned Op = NVPTX::SplitF16x2; + // If the vector has been BITCAST'ed from i32, we can use original + // value directly and avoid register-to-register move. + SDValue Source = Vector; + if (Vector->getOpcode() == ISD::BITCAST) { + Op = NVPTX::SplitI32toF16x2; + Source = Vector->getOperand(0); + } + // Merge (f16 extractelt(V, 0), f16 extractelt(V,1)) + // into f16,f16 SplitF16x2(V) + SDNode *ScatterOp = + CurDAG->getMachineNode(Op, SDLoc(N), MVT::f16, MVT::f16, Source); + for (auto *Node : E0) + ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0)); + for (auto *Node : E1) + ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 1)); + + return true; +} + static unsigned int getCodeAddrSpace(MemSDNode *N) { const Value *Src = N->getMemOperand()->getValue(); @@ -689,29 +818,26 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { codeAddrSpace != NVPTX::PTXLdStInstCode::GENERIC) isVolatile = false; - // Vector Setting - MVT SimpleVT = LoadedVT.getSimpleVT(); - unsigned vecType = NVPTX::PTXLdStInstCode::Scalar; - if (SimpleVT.isVector()) { - unsigned num = SimpleVT.getVectorNumElements(); - if (num == 2) - vecType = NVPTX::PTXLdStInstCode::V2; - else if (num == 4) - vecType = NVPTX::PTXLdStInstCode::V4; - else - return false; - } - // Type Setting: fromType + fromTypeWidth // // Sign : ISD::SEXTLOAD // Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the // type is integer // Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float + MVT SimpleVT = LoadedVT.getSimpleVT(); MVT ScalarVT = SimpleVT.getScalarType(); // Read at least 8 bits (predicates are stored as 8-bit values) unsigned fromTypeWidth = std::max(8U, ScalarVT.getSizeInBits()); unsigned int fromType; + + // Vector Setting + unsigned vecType = NVPTX::PTXLdStInstCode::Scalar; + if (SimpleVT.isVector()) { + assert(LoadedVT == MVT::v2f16 && "Unexpected vector type"); + // v2f16 is loaded using ld.b32 + fromTypeWidth = 32; + } + if ((LD->getExtensionType() == ISD::SEXTLOAD)) fromType = NVPTX::PTXLdStInstCode::Signed; else if (ScalarVT.isFloatingPoint()) @@ -746,6 +872,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::f16: Opcode = NVPTX::LD_f16_avar; break; + case MVT::v2f16: + Opcode = NVPTX::LD_f16x2_avar; + break; case MVT::f32: Opcode = NVPTX::LD_f32_avar; break; @@ -777,6 +906,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::f16: Opcode = NVPTX::LD_f16_asi; break; + case MVT::v2f16: + Opcode = NVPTX::LD_f16x2_asi; + break; case MVT::f32: Opcode = NVPTX::LD_f32_asi; break; @@ -809,6 +941,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::f16: Opcode = NVPTX::LD_f16_ari_64; break; + case MVT::v2f16: + Opcode = NVPTX::LD_f16x2_ari_64; + break; case MVT::f32: Opcode = NVPTX::LD_f32_ari_64; break; @@ -835,6 +970,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::f16: Opcode = NVPTX::LD_f16_ari; break; + case MVT::v2f16: + Opcode = NVPTX::LD_f16x2_ari; + break; case MVT::f32: Opcode = NVPTX::LD_f32_ari; break; @@ -867,6 +1005,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::f16: Opcode = NVPTX::LD_f16_areg_64; break; + case MVT::v2f16: + Opcode = NVPTX::LD_f16x2_areg_64; + break; case MVT::f32: Opcode = NVPTX::LD_f32_areg_64; break; @@ -893,6 +1034,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { case MVT::f16: Opcode = NVPTX::LD_f16_areg; break; + case MVT::v2f16: + Opcode = NVPTX::LD_f16x2_areg; + break; case MVT::f32: Opcode = NVPTX::LD_f32_areg; break; @@ -968,7 +1112,8 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { if (ExtensionType == ISD::SEXTLOAD) FromType = NVPTX::PTXLdStInstCode::Signed; else if (ScalarVT.isFloatingPoint()) - FromType = NVPTX::PTXLdStInstCode::Float; + FromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped + : NVPTX::PTXLdStInstCode::Float; else FromType = NVPTX::PTXLdStInstCode::Unsigned; @@ -987,6 +1132,16 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { EVT EltVT = N->getValueType(0); + // v8f16 is a special case. PTX doesn't have ld.v8.f16 + // instruction. Instead, we split the vector into v2f16 chunks and + // load them with ld.v4.b32. + if (EltVT == MVT::v2f16) { + assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode."); + EltVT = MVT::i32; + FromType = NVPTX::PTXLdStInstCode::Untyped; + FromTypeWidth = 32; + } + if (SelectDirectAddr(Op1, Addr)) { switch (N->getOpcode()) { default: @@ -1007,6 +1162,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { case MVT::i64: Opcode = NVPTX::LDV_i64_v2_avar; break; + case MVT::f16: + Opcode = NVPTX::LDV_f16_v2_avar; + break; case MVT::f32: Opcode = NVPTX::LDV_f32_v2_avar; break; @@ -1028,6 +1186,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { case MVT::i32: Opcode = NVPTX::LDV_i32_v4_avar; break; + case MVT::f16: + Opcode = NVPTX::LDV_f16_v4_avar; + break; case MVT::f32: Opcode = NVPTX::LDV_f32_v4_avar; break; @@ -1060,6 +1221,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { case MVT::i64: Opcode = NVPTX::LDV_i64_v2_asi; break; + case MVT::f16: + Opcode = NVPTX::LDV_f16_v2_asi; + break; case MVT::f32: Opcode = NVPTX::LDV_f32_v2_asi; break; @@ -1081,6 +1245,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { case MVT::i32: Opcode = NVPTX::LDV_i32_v4_asi; break; + case MVT::f16: + Opcode = NVPTX::LDV_f16_v4_asi; + break; case MVT::f32: Opcode = NVPTX::LDV_f32_v4_asi; break; @@ -1114,6 +1281,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { case MVT::i64: Opcode = NVPTX::LDV_i64_v2_ari_64; break; + case MVT::f16: + Opcode = NVPTX::LDV_f16_v2_ari_64; + break; case MVT::f32: Opcode = NVPTX::LDV_f32_v2_ari_64; break; @@ -1135,6 +1305,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { case MVT::i32: Opcode = NVPTX::LDV_i32_v4_ari_64; break; + case MVT::f16: + Opcode = NVPTX::LDV_f16_v4_ari_64; + break; case MVT::f32: Opcode = NVPTX::LDV_f32_v4_ari_64; break; @@ -1161,6 +1334,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { case MVT::i64: Opcode = NVPTX::LDV_i64_v2_ari; break; + case MVT::f16: + Opcode = NVPTX::LDV_f16_v2_ari; + break; case MVT::f32: Opcode = NVPTX::LDV_f32_v2_ari; break; @@ -1182,6 +1358,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { case MVT::i32: Opcode = NVPTX::LDV_i32_v4_ari; break; + case MVT::f16: + Opcode = NVPTX::LDV_f16_v4_ari; + break; case MVT::f32: Opcode = NVPTX::LDV_f32_v4_ari; break; @@ -1216,6 +1395,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { case MVT::i64: Opcode = NVPTX::LDV_i64_v2_areg_64; break; + case MVT::f16: + Opcode = NVPTX::LDV_f16_v2_areg_64; + break; case MVT::f32: Opcode = NVPTX::LDV_f32_v2_areg_64; break; @@ -1237,6 +1419,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { case MVT::i32: Opcode = NVPTX::LDV_i32_v4_areg_64; break; + case MVT::f16: + Opcode = NVPTX::LDV_f16_v4_areg_64; + break; case MVT::f32: Opcode = NVPTX::LDV_f32_v4_areg_64; break; @@ -1263,6 +1448,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { case MVT::i64: Opcode = NVPTX::LDV_i64_v2_areg; break; + case MVT::f16: + Opcode = NVPTX::LDV_f16_v2_areg; + break; case MVT::f32: Opcode = NVPTX::LDV_f32_v2_areg; break; @@ -1284,6 +1472,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { case MVT::i32: Opcode = NVPTX::LDV_i32_v4_areg; break; + case MVT::f16: + Opcode = NVPTX::LDV_f16_v4_areg; + break; case MVT::f32: Opcode = NVPTX::LDV_f32_v4_areg; break; @@ -2151,21 +2342,18 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { // Vector Setting MVT SimpleVT = StoreVT.getSimpleVT(); unsigned vecType = NVPTX::PTXLdStInstCode::Scalar; - if (SimpleVT.isVector()) { - unsigned num = SimpleVT.getVectorNumElements(); - if (num == 2) - vecType = NVPTX::PTXLdStInstCode::V2; - else if (num == 4) - vecType = NVPTX::PTXLdStInstCode::V4; - else - return false; - } // Type Setting: toType + toTypeWidth // - for integer type, always use 'u' // MVT ScalarVT = SimpleVT.getScalarType(); unsigned toTypeWidth = ScalarVT.getSizeInBits(); + if (SimpleVT.isVector()) { + assert(StoreVT == MVT::v2f16 && "Unexpected vector type"); + // v2f16 is stored using st.b32 + toTypeWidth = 32; + } + unsigned int toType; if (ScalarVT.isFloatingPoint()) // f16 uses .b16 as its storage type. @@ -2200,6 +2388,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::f16: Opcode = NVPTX::ST_f16_avar; break; + case MVT::v2f16: + Opcode = NVPTX::ST_f16x2_avar; + break; case MVT::f32: Opcode = NVPTX::ST_f32_avar; break; @@ -2232,6 +2423,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::f16: Opcode = NVPTX::ST_f16_asi; break; + case MVT::v2f16: + Opcode = NVPTX::ST_f16x2_asi; + break; case MVT::f32: Opcode = NVPTX::ST_f32_asi; break; @@ -2265,6 +2459,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::f16: Opcode = NVPTX::ST_f16_ari_64; break; + case MVT::v2f16: + Opcode = NVPTX::ST_f16x2_ari_64; + break; case MVT::f32: Opcode = NVPTX::ST_f32_ari_64; break; @@ -2291,6 +2488,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::f16: Opcode = NVPTX::ST_f16_ari; break; + case MVT::v2f16: + Opcode = NVPTX::ST_f16x2_ari; + break; case MVT::f32: Opcode = NVPTX::ST_f32_ari; break; @@ -2324,6 +2524,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::f16: Opcode = NVPTX::ST_f16_areg_64; break; + case MVT::v2f16: + Opcode = NVPTX::ST_f16x2_areg_64; + break; case MVT::f32: Opcode = NVPTX::ST_f32_areg_64; break; @@ -2350,6 +2553,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) { case MVT::f16: Opcode = NVPTX::ST_f16_areg; break; + case MVT::v2f16: + Opcode = NVPTX::ST_f16x2_areg; + break; case MVT::f32: Opcode = NVPTX::ST_f32_areg; break; @@ -2411,7 +2617,8 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { unsigned ToTypeWidth = ScalarVT.getSizeInBits(); unsigned ToType; if (ScalarVT.isFloatingPoint()) - ToType = NVPTX::PTXLdStInstCode::Float; + ToType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped + : NVPTX::PTXLdStInstCode::Float; else ToType = NVPTX::PTXLdStInstCode::Unsigned; @@ -2438,6 +2645,16 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { return false; } + // v8f16 is a special case. PTX doesn't have st.v8.f16 + // instruction. Instead, we split the vector into v2f16 chunks and + // store them with st.v4.b32. + if (EltVT == MVT::v2f16) { + assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected load opcode."); + EltVT = MVT::i32; + ToType = NVPTX::PTXLdStInstCode::Untyped; + ToTypeWidth = 32; + } + StOps.push_back(getI32Imm(IsVolatile, DL)); StOps.push_back(getI32Imm(CodeAddrSpace, DL)); StOps.push_back(getI32Imm(VecType, DL)); @@ -2464,6 +2681,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { case MVT::i64: Opcode = NVPTX::STV_i64_v2_avar; break; + case MVT::f16: + Opcode = NVPTX::STV_f16_v2_avar; + break; case MVT::f32: Opcode = NVPTX::STV_f32_v2_avar; break; @@ -2513,6 +2733,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { case MVT::i64: Opcode = NVPTX::STV_i64_v2_asi; break; + case MVT::f16: + Opcode = NVPTX::STV_f16_v2_asi; + break; case MVT::f32: Opcode = NVPTX::STV_f32_v2_asi; break; @@ -2534,6 +2757,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { case MVT::i32: Opcode = NVPTX::STV_i32_v4_asi; break; + case MVT::f16: + Opcode = NVPTX::STV_f16_v4_asi; + break; case MVT::f32: Opcode = NVPTX::STV_f32_v4_asi; break; @@ -2564,6 +2790,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { case MVT::i64: Opcode = NVPTX::STV_i64_v2_ari_64; break; + case MVT::f16: + Opcode = NVPTX::STV_f16_v2_ari_64; + break; case MVT::f32: Opcode = NVPTX::STV_f32_v2_ari_64; break; @@ -2585,6 +2814,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { case MVT::i32: Opcode = NVPTX::STV_i32_v4_ari_64; break; + case MVT::f16: + Opcode = NVPTX::STV_f16_v4_ari_64; + break; case MVT::f32: Opcode = NVPTX::STV_f32_v4_ari_64; break; @@ -2611,6 +2843,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { case MVT::i64: Opcode = NVPTX::STV_i64_v2_ari; break; + case MVT::f16: + Opcode = NVPTX::STV_f16_v2_ari; + break; case MVT::f32: Opcode = NVPTX::STV_f32_v2_ari; break; @@ -2632,6 +2867,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { case MVT::i32: Opcode = NVPTX::STV_i32_v4_ari; break; + case MVT::f16: + Opcode = NVPTX::STV_f16_v4_ari; + break; case MVT::f32: Opcode = NVPTX::STV_f32_v4_ari; break; @@ -2662,6 +2900,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { case MVT::i64: Opcode = NVPTX::STV_i64_v2_areg_64; break; + case MVT::f16: + Opcode = NVPTX::STV_f16_v2_areg_64; + break; case MVT::f32: Opcode = NVPTX::STV_f32_v2_areg_64; break; @@ -2683,6 +2924,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { case MVT::i32: Opcode = NVPTX::STV_i32_v4_areg_64; break; + case MVT::f16: + Opcode = NVPTX::STV_f16_v4_areg_64; + break; case MVT::f32: Opcode = NVPTX::STV_f32_v4_areg_64; break; @@ -2709,6 +2953,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { case MVT::i64: Opcode = NVPTX::STV_i64_v2_areg; break; + case MVT::f16: + Opcode = NVPTX::STV_f16_v2_areg; + break; case MVT::f32: Opcode = NVPTX::STV_f32_v2_areg; break; @@ -2730,6 +2977,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) { case MVT::i32: Opcode = NVPTX::STV_i32_v4_areg; break; + case MVT::f16: + Opcode = NVPTX::STV_f16_v4_areg; + break; case MVT::f32: Opcode = NVPTX::STV_f32_v4_areg; break; @@ -2804,6 +3054,9 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) { case MVT::f16: Opc = NVPTX::LoadParamMemF16; break; + case MVT::v2f16: + Opc = NVPTX::LoadParamMemF16x2; + break; case MVT::f32: Opc = NVPTX::LoadParamMemF32; break; @@ -2831,6 +3084,12 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) { case MVT::i64: Opc = NVPTX::LoadParamMemV2I64; break; + case MVT::f16: + Opc = NVPTX::LoadParamMemV2F16; + break; + case MVT::v2f16: + Opc = NVPTX::LoadParamMemV2F16x2; + break; case MVT::f32: Opc = NVPTX::LoadParamMemV2F32; break; @@ -2855,6 +3114,12 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) { case MVT::i32: Opc = NVPTX::LoadParamMemV4I32; break; + case MVT::f16: + Opc = NVPTX::LoadParamMemV4F16; + break; + case MVT::v2f16: + Opc = NVPTX::LoadParamMemV4F16x2; + break; case MVT::f32: Opc = NVPTX::LoadParamMemV4F32; break; @@ -2942,6 +3207,9 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) { case MVT::f16: Opcode = NVPTX::StoreRetvalF16; break; + case MVT::v2f16: + Opcode = NVPTX::StoreRetvalF16x2; + break; case MVT::f32: Opcode = NVPTX::StoreRetvalF32; break; @@ -2969,6 +3237,12 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) { case MVT::i64: Opcode = NVPTX::StoreRetvalV2I64; break; + case MVT::f16: + Opcode = NVPTX::StoreRetvalV2F16; + break; + case MVT::v2f16: + Opcode = NVPTX::StoreRetvalV2F16x2; + break; case MVT::f32: Opcode = NVPTX::StoreRetvalV2F32; break; @@ -2993,6 +3267,12 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) { case MVT::i32: Opcode = NVPTX::StoreRetvalV4I32; break; + case MVT::f16: + Opcode = NVPTX::StoreRetvalV4F16; + break; + case MVT::v2f16: + Opcode = NVPTX::StoreRetvalV4F16x2; + break; case MVT::f32: Opcode = NVPTX::StoreRetvalV4F32; break; @@ -3000,8 +3280,7 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) { break; } - SDNode *Ret = - CurDAG->getMachineNode(Opcode, DL, MVT::Other, Ops); + SDNode *Ret = CurDAG->getMachineNode(Opcode, DL, MVT::Other, Ops); MachineSDNode::mmo_iterator MemRefs0 = MF->allocateMemRefsArray(1); MemRefs0[0] = cast<MemSDNode>(N)->getMemOperand(); cast<MachineSDNode>(Ret)->setMemRefs(MemRefs0, MemRefs0 + 1); @@ -3078,6 +3357,9 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) { case MVT::f16: Opcode = NVPTX::StoreParamF16; break; + case MVT::v2f16: + Opcode = NVPTX::StoreParamF16x2; + break; case MVT::f32: Opcode = NVPTX::StoreParamF32; break; @@ -3105,6 +3387,12 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) { case MVT::i64: Opcode = NVPTX::StoreParamV2I64; break; + case MVT::f16: + Opcode = NVPTX::StoreParamV2F16; + break; + case MVT::v2f16: + Opcode = NVPTX::StoreParamV2F16x2; + break; case MVT::f32: Opcode = NVPTX::StoreParamV2F32; break; @@ -3129,6 +3417,12 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) { case MVT::i32: Opcode = NVPTX::StoreParamV4I32; break; + case MVT::f16: + Opcode = NVPTX::StoreParamV4F16; + break; + case MVT::v2f16: + Opcode = NVPTX::StoreParamV4F16x2; + break; case MVT::f32: Opcode = NVPTX::StoreParamV4F32; break; |