// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier:  MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/smoothquant.hpp"
#include <string>

// Bundle of the data types used by MoE smoothquant.
//   InputType  -> exposed as XDataType
//   OutputType -> exposed as QYDataType
// The scale and compute types do not depend on the template arguments.
template <typename InputType, typename OutputType>
struct MoeSmoothquantTypeConfig
{
    // caller-selected types
    using XDataType  = InputType;
    using QYDataType = OutputType;

    // fixed to float for every instantiation
    using ComputeDataType     = float;
    using SmoothScaleDataType = ComputeDataType;
    using YScaleDataType      = ComputeDataType;
};

// Runtime arguments of the example API. Adds no members of its own: everything
// comes from ck_tile::MoeSmoothquantHostArgs (struct inheritance is public by default).
struct moe_smoothquant_args : ck_tile::MoeSmoothquantHostArgs
{
};

// Compile-time description of one kernel configuration. It is used to pattern-match
// the internal kernel implementation, not to instantiate a kernel.
template <typename InputType_,
          typename OutputType_,
          ck_tile::index_t Repeat_M_,         // per-thread repeat count along M
          ck_tile::index_t Repeat_N_,         // per-thread repeat count along N
          ck_tile::index_t ThreadPerBlock_M_, // number of threads along M
          ck_tile::index_t ThreadPerBlock_N_, // number of threads along N
          ck_tile::index_t Vector_N_,         // vector size along N
          bool kPadN_,
          bool kTwoPass_>
struct moe_smoothquant_traits_
{
    // element types, passed through ck_tile::remove_cvref_t
    using InputType  = ck_tile::remove_cvref_t<InputType_>;
    using OutputType = ck_tile::remove_cvref_t<OutputType_>;

    // boolean switches forwarded unchanged from the template arguments
    static constexpr bool kPadN    = kPadN_;
    static constexpr bool kTwoPass = kTwoPass_;

    static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
    static constexpr ck_tile::index_t Repeat_N = Repeat_N_;

    // block tile extents:
    //   M = repeat * threads
    //   N = repeat * threads * vector size
    static constexpr ck_tile::index_t Block_M = Repeat_M * ThreadPerBlock_M_;
    static constexpr ck_tile::index_t Block_N = Repeat_N * ThreadPerBlock_N_ * Vector_N_;

    // the three sequences below are combined into the 2D block shape
    using Vector         = ck_tile::sequence<1, Vector_N_>;
    using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
    using BlockTile      = ck_tile::sequence<Block_M, Block_N>;
    using Shape          = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
};

// Internal entry point, selected by a traits type. Only the declaration lives in this
// header; the definition has to be provided elsewhere.
//   Traits_ : presumably an instantiation of moe_smoothquant_traits_ -- TODO confirm
//   s       : ck_tile stream configuration, taken by const reference
//   a       : runtime arguments, taken by value
// Returns a float. NOTE(review): likely the measured kernel time -- confirm at the definition.
template <typename Traits_>
float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a);

// This is the public API; it will be generated by script.
// Templated on the input and output types only. Note the parameter order here
// (runtime args first, stream config second) is the reverse of moe_smoothquant_.
// Returns a float. NOTE(review): likely the measured kernel time -- confirm in the generated code.
template <typename InputType, typename OutputType>
float moe_smoothquant(moe_smoothquant_args, const ck_tile::stream_config&);
