// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Shorthand for a compile-time integer list; the instance table uses it for
// the thread-cluster length and access-order arguments.
template <index_t... Values>
using S = Sequence<Values...>;

// Element types referenced by the instance table.
using F32  = float;
using BF16 = bhalf_t;
using I8   = int8_t;

// Matrix layout tags.
using Col = tensor_layout::gemm::ColumnMajor;
using Row = tensor_layout::gemm::RowMajor;

// Element-wise operation supplied for the A, B and C operation slots of every
// instance in this file.
using PassThrough = element_wise::PassThrough;

// Empty-tuple aliases for the Ds layout / data type slots. The instance alias
// template in this file declares template parameters with these same names,
// which take precedence inside that template.
using DsDataType = ck::Tuple<>;
using DsLayout   = ck::Tuple<>;

static constexpr auto GemmDefault    = GemmSpecialization::Default;
static constexpr auto GemmKPadding   = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding  = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;

static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

// Tuple of DeviceGemm_Wmma_CShuffleV3R1 instantiations for a GEMM with
// row-major A (BF16), row-major B (I8) and row-major C (BF16), using F32 as
// the accumulator type, BF16 as the CShuffle type and float as the reduce
// type. The column header inside the tuple names each positional argument.
//
// Template parameters:
//   GemmSpec   - GemmSpecialization forwarded unchanged to every instance.
//   DsLayout   - forwarded to the DsLayout slot; defaults to ck::Tuple<>.
//   DsDataType - forwarded to the DsData slot; defaults to ck::Tuple<>.
// Note: DsLayout / DsDataType are template parameters here and hide the
// namespace-scope aliases of the same names declared earlier in this file.
//
// The 18 entries are grouped as:
//   - 7 tile configurations, Intrawave scheduler, pipeline v1
//   - 4 tile configurations, Interwave scheduler, pipeline v1
//   - the same 7 tile configurations as the first group, Intrawave, pipeline v3
// NOTE(review): presumably consumed by add_device_operation_instances (its
// header is included at the top of this file) - confirm against the
// translation unit that registers these instances.
template <GemmSpecialization GemmSpec,
          typename DsLayout   = ck::Tuple<>,
          typename DsDataType = ck::Tuple<>>
using device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances =
    std::tuple<
        // clang-format off
        //#########################| ALayout| BLayout|  DsLayout| CLayout| AData| BData|      DsData| CData| AccData| CShuffle|           A|           B|           C|          GEMM| Block|  MPer|  NPer|  KPer| AK1| BK1|MPerWmma|NPerWmma|MRepeat|NRepeat|  ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|  BBlockTransfer| BBlockTransfer| BBlockTransfer|BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|  CShuffle|  CShuffle|         CBlockTransferClusterLengths|  CBlockTransfer|                         Block-wiseGemm|               Block-wiseGemm|      Reduce|
        //#########################|        |        |          |        |  Type|  Type|        Type|  Type|    Type|     Type| Elementwise| Elementwise| Elementwise|Specialization|  Size| Block| Block| Block|    |    |        |        |       |       |   ThreadCluster|  ThreadCluster| SrcAccessOrder|   SrcVectorDim|      SrcScalar|      DstScalar| AddExtraM|   ThreadCluster|  ThreadCluster| SrcAccessOrder|  SrcVectorDim|      SrcScalar|      DstScalar| AddExtraN|MRepeatPer|NRepeatPer|     _MBlock_MRepeatPerShuffle_MWaveM| ScalarPerVector|                               Pipeline|                     Pipeline|    DataType|
        //#########################|        |        |          |        |      |      |            |      |        |         |   Operation|   Operation|   Operation|              |      |      |      |      |    |    |        |        |       |       | Lengths_K0_M_K1|   ArrangeOrder|               |               |      PerVector|   PerVector_K1|          | Lengths_K0_N_K1|   ArrangeOrder|               |              |      PerVector|   PerVector_K1|          | Shuffle  | Shuffle  |  PerShuffle_NBlock_NRepeatPerShuffle|   _NPerBlock   |                              Scheduler|                      Version|            |
        //#########################|        |        |          |        |      |      |            |      |        |         |            |            |            |              |      |      |      |      |    |    |        |        |       |       |                |               |               |               |               |               |          |                |               |               |              |               |               |          |          |          | _NWaveNPerRepeat                    |                |                                       |                             |            |
        // Intrawave scheduler, pipeline v1
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   256,   128,   128,    32,   8,   4,      16,      16,      4,      2,     S<4, 64, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 8>,               8,  BlockGemmPipelineScheduler::Intrawave,  BlockGemmPipelineVersion::v1,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   256,   128,   256,    32,   8,   4,      16,      16,      4,      4,     S<4, 64, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 8>,               8,  BlockGemmPipelineScheduler::Intrawave,  BlockGemmPipelineVersion::v1,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   256,   256,   128,    32,   8,   4,      16,      16,      8,      2,     S<4, 64, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 8>,               8,  BlockGemmPipelineScheduler::Intrawave,  BlockGemmPipelineVersion::v1,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   256,   128,   128,    64,   8,   4,      16,      16,      4,      2,     S<8, 32, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<8, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 8>,               8,  BlockGemmPipelineScheduler::Intrawave,  BlockGemmPipelineVersion::v1,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   256,   128,   256,    64,   8,   4,      16,      16,      4,      4,     S<8, 32, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<8, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 8>,               8,  BlockGemmPipelineScheduler::Intrawave,  BlockGemmPipelineVersion::v1,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   128,    64,    64,    32,   8,   4,      16,      16,      2,      2,     S<4, 32, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<4, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              2,              2,     true,          1,         1,                       S<1, 32, 1, 2>,               8,  BlockGemmPipelineScheduler::Intrawave,  BlockGemmPipelineVersion::v1,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   128,   128,    64,    64,   8,   4,      16,      16,      4,      2,     S<4, 32, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<4, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 4>,               8,  BlockGemmPipelineScheduler::Intrawave,  BlockGemmPipelineVersion::v1,      float>,
        // Interwave scheduler, pipeline v1
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   256,   128,   256,    32,   8,   4,      16,      16,      4,      4,     S<4, 64, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 8>,               8,  BlockGemmPipelineScheduler::Interwave,  BlockGemmPipelineVersion::v1,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   256,   128,   128,    64,   8,   4,      16,      16,      4,      2,     S<8, 32, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<8, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 8>,               8,  BlockGemmPipelineScheduler::Interwave,  BlockGemmPipelineVersion::v1,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   128,   128,   128,    32,   8,   4,      16,      16,      4,      4,     S<4, 32, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<4, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 4>,               8,  BlockGemmPipelineScheduler::Interwave,  BlockGemmPipelineVersion::v1,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   256,   128,   128,    32,   8,   4,      16,      16,      4,      2,     S<4, 64, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 8>,               8,  BlockGemmPipelineScheduler::Interwave,  BlockGemmPipelineVersion::v1,      float>,
        // Intrawave scheduler, pipeline v3 (same tile configurations as the first group)
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   256,   128,   128,    32,   8,   4,      16,      16,      4,      2,     S<4, 64, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 8>,               8,  BlockGemmPipelineScheduler::Intrawave,  BlockGemmPipelineVersion::v3,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   256,   128,   256,    32,   8,   4,      16,      16,      4,      4,     S<4, 64, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 8>,               8,  BlockGemmPipelineScheduler::Intrawave,  BlockGemmPipelineVersion::v3,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   256,   256,   128,    32,   8,   4,      16,      16,      8,      2,     S<4, 64, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<4, 64, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 8>,               8,  BlockGemmPipelineScheduler::Intrawave,  BlockGemmPipelineVersion::v3,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   256,   128,   128,    64,   8,   4,      16,      16,      4,      2,     S<8, 32, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<8, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 8>,               8,  BlockGemmPipelineScheduler::Intrawave,  BlockGemmPipelineVersion::v3,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   256,   128,   256,    64,   8,   4,      16,      16,      4,      4,     S<8, 32, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<8, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 8>,               8,  BlockGemmPipelineScheduler::Intrawave,  BlockGemmPipelineVersion::v3,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   128,    64,    64,    32,   8,   4,      16,      16,      2,      2,     S<4, 32, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<4, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              2,              2,     true,          1,         1,                       S<1, 32, 1, 2>,               8,  BlockGemmPipelineScheduler::Intrawave,  BlockGemmPipelineVersion::v3,      float>,
        DeviceGemm_Wmma_CShuffleV3R1<    Row,     Row,  DsLayout,     Row,  BF16,    I8, DsDataType,  BF16,     F32,     BF16, PassThrough, PassThrough, PassThrough,      GemmSpec,   128,   128,    64,    64,   8,   4,      16,      16,      4,      2,     S<4, 32, 1>,     S<0, 2, 1>,    S<0, 2, 1>,               1,              1,              8,       true,    S<4, 32, 1>,     S<0, 2, 1>,     S<0, 2, 1>,             1,              1,              4,     true,          1,         1,                       S<1, 32, 1, 4>,               8,  BlockGemmPipelineScheduler::Intrawave,  BlockGemmPipelineVersion::v3,      float>
        // clang-format on
        >;

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
