// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>

#include "fused_moesorting.hpp"
#include "fused_moegemm.hpp"

// Runtime arguments handed to fused_moe(): buffer pointers plus problem sizes.
//
// Every member carries a default initializer so that a default-initialized
// object holds null pointers and zero sizes instead of indeterminate values.
// The struct is still an aggregate (C++14 and later), so positional
// brace-initialization keeps working with the same member order.
// The shape annotations below are carried over from the original field notes.
struct fused_moe_args
{
    const void* a_ptr              = nullptr; // [m, k], input token
    const void* a_scale_ptr        = nullptr; // [m, 1], token scale
    const void* g_ptr              = nullptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
    const void* d_ptr              = nullptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
    const void* g_scale_ptr        = nullptr; // [e, 1, n], gate(up) scale
    const void* d_scale_ptr        = nullptr; // [e, 1, k], down scale
    const void* y_smooth_scale_ptr = nullptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
    void* o_ptr                    = nullptr; // [m, k], output token (no need to do zeroing)

    const void* topk_ids_ptr    = nullptr; // [tokens, topk]
    const void* topk_weight_ptr = nullptr; // [tokens, topk]
    void* sorted_token_ids_ptr  = nullptr; // [max_num_tokens_padded]
    void* sorted_weight_ptr     = nullptr; // [max_num_tokens_padded]
    void* sorted_expert_ids_ptr = nullptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
    void* num_sorted_tiles_ptr  = nullptr; // [1]

    ck_tile::index_t block_m           = 0; // block_m, used to divide the input
    ck_tile::index_t hidden_size       = 0; // k
    ck_tile::index_t intermediate_size = 0; // n / TP, for Gate. and Up, Down is also this value
    ck_tile::index_t num_tokens        = 0; // input number of tokens for current iteration
    ck_tile::index_t num_experts       = 0; // number of groups
    ck_tile::index_t topk              = 0; // need this? (open question from the original author)

    // for input/output, stride for each row, should >= hidden_size
    ck_tile::index_t stride_token = 0;
};

// This is the public API, will be generated by script.
//
// Configuration record passed by value to fused_moe(). The precision fields
// are free-form strings; the set of accepted spellings is not defined in this
// header (NOTE(review): confirm against the dispatcher that consumes them).
//
// The integer selectors are given explicit zero defaults so that a
// default-initialized object never exposes indeterminate values; the string
// members already default to empty. The struct remains an aggregate
// (C++14 and later), so positional brace-initialization is unaffected.
struct fused_moe_traits
{
    std::string prec_i;  // input precision
    std::string prec_w;  // weight precision
    std::string prec_o;  // output precision
    std::string prec_st; // token scale data type
    std::string prec_sw; // weight scale data type
    std::string prec_sq; // smooth quant scale
    std::string prec_kw; // topk-weight data type
    int block_m     = 0;
    int activation  = 0; // 0:gelu, 1:silu
    int gate_only   = 0; // 0:g1u0, 1:g1u1
    int fused_quant = 0; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};

// Entry point of the fused MoE API; only the declaration lives in this header.
// Takes the trait record and the argument record by value, and the stream
// configuration by const reference.
// NOTE(review): the meaning of the float result is not visible here
// (presumably a measured execution time) - confirm against the definition.
float fused_moe(fused_moe_traits traits,
                fused_moe_args args,
                const ck_tile::stream_config& stream);
