diff options
author | Artem Belevich <tra@google.com> | 2018-03-15 21:40:56 +0000 |
---|---|---|
committer | Artem Belevich <tra@google.com> | 2018-03-15 21:40:56 +0000 |
commit | 6a06998de9cf8d5885a37fcd06c42a65b1da34e9 (patch) | |
tree | ad17f876b045a85e7cca82ec874d1d928027e5b7 /lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | |
parent | 648727d7b98bacb615a8be81ea0b4b74b6d8f409 (diff) |
[NVPTX] TblGen-ized lowering of WMMA intrinsics.
NFC.
Differential Revision: https://reviews.llvm.org/D43151
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@327672 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp')
-rw-r--r-- | lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 518 |
1 files changed, 6 insertions, 512 deletions
diff --git a/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index d831f9f19a5..99305440eef 100644 --- a/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -496,318 +496,8 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) { SelectCode(N); } -// Each instruction has four addressing variants. WMMA_VARIANTS() macro below -// constructs an array indexed by WmmaVariant which getWmmaLdVariant() uses to -// look up the intrinsic ID of particular variant. -enum WmmaVariant { - WMMA_VARIANT_ARI64, - WMMA_VARIANT_ARI64_STRIDE, - WMMA_VARIANT_AVAR, - WMMA_VARIANT_AVAR_STRIDE, -}; - -// clang-format off -#define WMMA_VARIANTS(base) \ - {{ base##_ari64, base##_ari64_stride, base##_avar, base##_avar_stride }} -// clang-format on - -static unsigned getWmmaLdVariant(WmmaVariant Variant, bool Stride, - const std::array<unsigned, 4> Variants) { - if (Stride) { - if (Variant == WMMA_VARIANT_ARI64) - Variant = WMMA_VARIANT_ARI64_STRIDE; - else if (Variant == WMMA_VARIANT_AVAR) - Variant = WMMA_VARIANT_AVAR_STRIDE; - } - return Variants[Variant]; -} - -static Optional<unsigned> -getWmmaLdStOpcode(unsigned IntrinsicID, - WmmaVariant Variant = WMMA_VARIANT_ARI64) { - switch (IntrinsicID) { - default: - return None; - // - // WMMA_LOAD_A f16 - // - case Intrinsic::nvvm_wmma_load_a_f16_col: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col)); - case Intrinsic::nvvm_wmma_load_a_f16_row: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row)); - case Intrinsic::nvvm_wmma_load_a_f16_col_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col)); - case Intrinsic::nvvm_wmma_load_a_f16_row_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row)); - case Intrinsic::nvvm_wmma_load_a_f16_col_shared: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared)); - case Intrinsic::nvvm_wmma_load_a_f16_row_shared: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared)); - case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared)); - case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared)); - case Intrinsic::nvvm_wmma_load_a_f16_col_global: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global)); - case Intrinsic::nvvm_wmma_load_a_f16_row_global: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global)); - case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global)); - case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global)); - - // - // WMMA_LOAD_B f16 - // - case Intrinsic::nvvm_wmma_load_b_f16_col: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col)); - case Intrinsic::nvvm_wmma_load_b_f16_row: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row)); - case Intrinsic::nvvm_wmma_load_b_f16_col_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col)); - case Intrinsic::nvvm_wmma_load_b_f16_row_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row)); - case Intrinsic::nvvm_wmma_load_b_f16_col_shared: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared)); - case Intrinsic::nvvm_wmma_load_b_f16_row_shared: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared)); - case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared)); - case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared)); - case Intrinsic::nvvm_wmma_load_b_f16_col_global: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global)); - case Intrinsic::nvvm_wmma_load_b_f16_row_global: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global)); - case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global)); - case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global)); - - // - // WMMA_LOAD_C f16 - // - case Intrinsic::nvvm_wmma_load_c_f16_col: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col)); - case Intrinsic::nvvm_wmma_load_c_f16_row: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row)); - case Intrinsic::nvvm_wmma_load_c_f16_col_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col)); - case Intrinsic::nvvm_wmma_load_c_f16_row_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row)); - case Intrinsic::nvvm_wmma_load_c_f16_col_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared)); - case Intrinsic::nvvm_wmma_load_c_f16_row_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared)); - case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared)); - case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared)); - case Intrinsic::nvvm_wmma_load_c_f16_col_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global)); - case Intrinsic::nvvm_wmma_load_c_f16_row_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global)); - case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global)); - case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global)); - - // - // WMMA_LOAD_C f32 - // - case Intrinsic::nvvm_wmma_load_c_f32_col: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col)); - case Intrinsic::nvvm_wmma_load_c_f32_row: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row)); - case Intrinsic::nvvm_wmma_load_c_f32_col_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col)); - case Intrinsic::nvvm_wmma_load_c_f32_row_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row)); - case Intrinsic::nvvm_wmma_load_c_f32_col_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared)); - case Intrinsic::nvvm_wmma_load_c_f32_row_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared)); - case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared)); - case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared)); - case Intrinsic::nvvm_wmma_load_c_f32_col_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global)); - case Intrinsic::nvvm_wmma_load_c_f32_row_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global)); - case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global)); - case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global)); - - // - // WMMA_STORE_D f16 - // - case Intrinsic::nvvm_wmma_store_d_f16_col: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col)); - case Intrinsic::nvvm_wmma_store_d_f16_row: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row)); - case Intrinsic::nvvm_wmma_store_d_f16_col_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col)); - case Intrinsic::nvvm_wmma_store_d_f16_row_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row)); - case Intrinsic::nvvm_wmma_store_d_f16_col_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared)); - case Intrinsic::nvvm_wmma_store_d_f16_row_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared)); - case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared)); - case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared)); - case Intrinsic::nvvm_wmma_store_d_f16_col_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global)); - case Intrinsic::nvvm_wmma_store_d_f16_row_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global)); - case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global)); - case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global)); - - // - // WMMA_STORE_D f32 - // - case Intrinsic::nvvm_wmma_store_d_f32_col: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col)); - case Intrinsic::nvvm_wmma_store_d_f32_row: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row)); - case Intrinsic::nvvm_wmma_store_d_f32_col_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col)); - case Intrinsic::nvvm_wmma_store_d_f32_row_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row)); - case Intrinsic::nvvm_wmma_store_d_f32_col_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared)); - case Intrinsic::nvvm_wmma_store_d_f32_row_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared)); - case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared)); - case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared)); - case Intrinsic::nvvm_wmma_store_d_f32_col_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global)); - case Intrinsic::nvvm_wmma_store_d_f32_row_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global)); - case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global)); - case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global)); - } -} -#undef WMMA_VARIANTS - bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) { unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue(); - if (getWmmaLdStOpcode(IID)) - return tryWMMA_LDST(N); - switch (IID) { default: return false; @@ -1026,39 +716,6 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) { case Intrinsic::nvvm_texsurf_handle_internal: SelectTexSurfHandle(N); return true; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16: - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32: - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16: - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32: - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite: - return tryWMMA_MMA(N); } } @@ -3946,6 +3603,12 @@ bool NVPTXDAGToDAGISel::SelectADDRri64(SDNode *OpNode, SDValue Addr, return SelectADDRri_imp(OpNode, Addr, Base, Offset, MVT::i64); } +// symbol +bool NVPTXDAGToDAGISel::SelectADDRvar(SDNode *OpNode, SDValue Addr, + SDValue &Value) { + return SelectDirectAddr(Addr, Value); +} + bool NVPTXDAGToDAGISel::ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const { const Value *Src = nullptr; @@ -4038,172 +3701,3 @@ unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy, } } } - -bool NVPTXDAGToDAGISel::tryWMMA_LDST(SDNode *N) { - SDValue Chain = N->getOperand(0); - unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue(); - SDValue Op1 = N->getOperand(2); - SDValue Addr, Offset, Base; - Optional<unsigned> Opcode; - SDLoc DL(N); - MemSDNode *MemSD = cast<MemIntrinsicSDNode>(N); - WmmaVariant Variant; - SmallVector<SDValue, 12> Ops; - bool isStore = N->getNumValues() == 1; // Store ops only return a chain. - - if (SelectDirectAddr(Op1, Addr)) { - Variant = WMMA_VARIANT_AVAR; - Ops.push_back(Addr); - } else if (SelectADDRsi64(Op1.getNode(), Op1, Base, Offset) || - SelectADDRri64(Op1.getNode(), Op1, Base, Offset)) { - Variant = WMMA_VARIANT_ARI64; - Ops.push_back(Base); - Ops.push_back(Offset); - } else { - Variant = WMMA_VARIANT_AVAR; - Ops.push_back(Op1); - } - unsigned NumOps = N->getNumOperands(); - // Pass through the rest of the operands to the machine node. - for (unsigned i = 3; i < NumOps; ++i) - Ops.push_back(N->getOperand(i)); - Ops.push_back(Chain); - - Opcode = getWmmaLdStOpcode(IID, Variant); - if (!Opcode) { - llvm::errs() << "tryWMMALD - no Opcode.\n"; - return false; - } - - EVT MemVT = MemSD->getMemoryVT(); - assert(MemVT.isVector() && "Expected vector return type."); - - SDNode *MN; - if (isStore) { - MN = CurDAG->getMachineNode(Opcode.getValue(), DL, MVT::Other, Ops); - } else { - SmallVector<EVT, 9> InstVTs(MemVT.getVectorNumElements(), - MemSD->getValueType(0)); - InstVTs.push_back(MVT::Other); - MN = CurDAG->getMachineNode(Opcode.getValue(), DL, InstVTs, Ops); - } - - ReplaceNode(N, MN); - return true; -} - -bool NVPTXDAGToDAGISel::tryWMMA_MMA(SDNode *N) { - unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue(); - SDLoc DL(N); - unsigned Opc; - - switch (IID) { - default: - return false; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16: - Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32: - Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16: - Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32: - Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16: - Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32: - Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16: - Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32: - Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16: - Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32: - Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16: - Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32: - Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16: - Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32: - Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16: - Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32: - Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32_satfinite; - break; - } - - SmallVector<SDValue, 24> Ops; - // Pass through operands and return value types to the machine node. - for (unsigned i = 1; i < N->getNumOperands(); ++i) - Ops.push_back(N->getOperand(i)); - SmallVector<EVT, 8> InstVTs(N->getNumValues(), N->getValueType(0)); - SDNode *MN = CurDAG->getMachineNode(Opc, DL, InstVTs, Ops); - ReplaceNode(N, MN); - return true; -} |