/*===--------------------------------------------------------------------------
 *                   ROCm Device Libraries
 *
 * This file is distributed under the University of Illinois Open Source
 * License. See LICENSE.TXT for details.
 *===------------------------------------------------------------------------*/

#include "irif.h"
#include "device_amd_hsa.h"

#define ATTR __attribute__((const))

// Global offset for dimension `dim`.
// The start of the implicit-argument block is read as an array of size_t
// values, one per dimension (indices 0, 1 and 2). Any other dimension
// yields 0.
ATTR size_t
__ockl_get_global_offset(uint dim)
{
    // Only dimensions 0..2 have a slot; everything else reports no offset.
    if (dim > 2)
        return 0;

    // TODO find out if implicit arg pointer is aligned properly
    __constant size_t *offsets =
        (__constant size_t *)__builtin_amdgcn_implicitarg_ptr();

    return offsets[dim];
}

// Global work-item ID along `dim`: work-group ID times work-group size, plus
// the work-item ID inside the group, plus the global offset for `dim`.
// For a dimension other than 0..2 the ID part is 0, so the result is just
// whatever __ockl_get_global_offset returns for that dimension.
ATTR size_t
__ockl_get_global_id(uint dim)
{
    // Values used when `dim` is not 0, 1 or 2.
    uint item = 0;
    uint group = 0;
    uint extent = 1;

    if (dim == 0) {
        item = __builtin_amdgcn_workitem_id_x();
        group = __builtin_amdgcn_workgroup_id_x();
        extent = __builtin_amdgcn_workgroup_size_x();
    } else if (dim == 1) {
        item = __builtin_amdgcn_workitem_id_y();
        group = __builtin_amdgcn_workgroup_id_y();
        extent = __builtin_amdgcn_workgroup_size_y();
    } else if (dim == 2) {
        item = __builtin_amdgcn_workitem_id_z();
        group = __builtin_amdgcn_workgroup_id_z();
        extent = __builtin_amdgcn_workgroup_size_z();
    }

    // The un-offset ID is formed in 32-bit arithmetic and only widened when
    // the size_t offset is added.
    uint within_grid = group * extent + item;

    return within_grid + __ockl_get_global_offset(dim);
}

// Work-item ID within its work-group along `dim` (0 = x, 1 = y, 2 = z).
// Any other dimension yields 0.
ATTR size_t
__ockl_get_local_id(uint dim)
{
    if (dim == 0)
        return __builtin_amdgcn_workitem_id_x();

    if (dim == 1)
        return __builtin_amdgcn_workitem_id_y();

    if (dim == 2)
        return __builtin_amdgcn_workitem_id_z();

    return 0;
}

// Work-group ID along `dim` (0 = x, 1 = y, 2 = z).
// Any other dimension yields 0.
ATTR size_t
__ockl_get_group_id(uint dim)
{
    if (dim == 0)
        return __builtin_amdgcn_workgroup_id_x();

    if (dim == 1)
        return __builtin_amdgcn_workgroup_id_y();

    if (dim == 2)
        return __builtin_amdgcn_workgroup_id_z();

    return 0;
}

// Grid size along `dim`, read from the grid_size_{x,y,z} fields of the
// dispatch packet. Any dimension other than 0..2 yields 1.
ATTR size_t
__ockl_get_global_size(uint dim)
{
    __constant hsa_kernel_dispatch_packet_t *packet = __builtin_amdgcn_dispatch_ptr();

    if (dim == 0)
        return packet->grid_size_x;

    if (dim == 1)
        return packet->grid_size_y;

    if (dim == 2)
        return packet->grid_size_z;

    return 1;
}

// Actual size of the calling work-item's work-group along `dim`.
//
// For dimensions 0..2 this is the enqueued work-group size, except for a
// trailing partial group: when fewer work-items remain between this group's
// first work-item and the end of the grid, that smaller remainder is
// returned instead.
//
// For any other dimension the result is 1, as OpenCL specifies for an
// out-of-range dimension and as the sibling size queries in this file
// already do. (Previously this path returned 0.)
ATTR size_t
__ockl_get_local_size(uint dim)
{
    __constant hsa_kernel_dispatch_packet_t *p = __builtin_amdgcn_dispatch_ptr();
    uint group_id, grid_size, group_size;

    switch(dim) {
    case 0:
        group_id = __builtin_amdgcn_workgroup_id_x();
        group_size = __builtin_amdgcn_workgroup_size_x();
        grid_size = p->grid_size_x;
        break;
    case 1:
        group_id = __builtin_amdgcn_workgroup_id_y();
        group_size = __builtin_amdgcn_workgroup_size_y();
        grid_size = p->grid_size_y;
        break;
    case 2:
        group_id = __builtin_amdgcn_workgroup_id_z();
        group_size = __builtin_amdgcn_workgroup_size_z();
        grid_size = p->grid_size_z;
        break;
    default:
        // Out-of-range dimension: report a size of 1 rather than feeding
        // placeholder values through the clamp below, which produced 0.
        return 1;
    }

    // Work-items left from the start of this group to the end of the grid.
    uint r = grid_size - group_id * group_size;

    // A full group reports the enqueued size; a partial last group reports
    // only what is left.
    return (r < group_size) ? r : group_size;
}

// Number of work-groups along `dim`: the grid size divided by the work-group
// size, rounded up so a trailing partial group is counted.
// Any dimension other than 0..2 yields 1.
ATTR size_t
__ockl_get_num_groups(uint dim)
{
    __constant hsa_kernel_dispatch_packet_t *packet = __builtin_amdgcn_dispatch_ptr();

    // Values used when `dim` is not 0, 1 or 2; they divide out to exactly 1.
    uint grid = 1;
    uint group = 1;

    if (dim == 0) {
        grid = packet->grid_size_x;
        group = __builtin_amdgcn_workgroup_size_x();
    } else if (dim == 1) {
        grid = packet->grid_size_y;
        group = __builtin_amdgcn_workgroup_size_y();
    } else if (dim == 2) {
        grid = packet->grid_size_z;
        group = __builtin_amdgcn_workgroup_size_z();
    }

    uint whole = grid / group;
    uint covered = whole * group;

    // One extra group when the whole groups do not cover the entire grid.
    uint partial = (grid > covered) ? 1u : 0u;

    return whole + partial;
}

// Number of dimensions in use, taken from the dispatch packet's `setup`
// field. The field is returned unmasked, so this relies on `setup` holding
// nothing but the dimension count.
ATTR uint
__ockl_get_work_dim(void)
{
    __constant hsa_kernel_dispatch_packet_t *packet = __builtin_amdgcn_dispatch_ptr();

    // XXX revisit this if setup field ever changes
    uint dims = packet->setup;

    return dims;
}

// Work-group size along `dim` as reported by the workgroup_size builtins
// (0 = x, 1 = y, 2 = z), with no adjustment for a partial trailing group.
// Any other dimension yields 1.
ATTR size_t
__ockl_get_enqueued_local_size(uint dim)
{
    if (dim == 0)
        return __builtin_amdgcn_workgroup_size_x();

    if (dim == 1)
        return __builtin_amdgcn_workgroup_size_y();

    if (dim == 2)
        return __builtin_amdgcn_workgroup_size_z();

    return 1;
}

// Linearised global work-item ID, with x as the fastest-varying dimension.
// The dimension count is read from the dispatch packet's `setup` field:
//   1: x
//   2: y * grid_size_x + x
//   3: (z * grid_size_y + y) * grid_size_x + x
// where x, y, z are the per-dimension IDs (group ID * group size + item ID)
// without any global offset. Any other `setup` value yields 0.
ATTR size_t
__ockl_get_global_linear_id(void)
{
    __constant hsa_kernel_dispatch_packet_t *packet = __builtin_amdgcn_dispatch_ptr();

    // XXX revisit this if setup field ever changes
    uint dims = packet->setup;

    if (dims < 1 || dims > 3)
        return 0;

    // Per-dimension IDs are formed in 32 bits; only the cross-dimension
    // combination below is widened to size_t.
    uint x = __builtin_amdgcn_workgroup_id_x() * __builtin_amdgcn_workgroup_size_x() +
             __builtin_amdgcn_workitem_id_x();

    if (dims == 1)
        return x;

    uint y = __builtin_amdgcn_workgroup_id_y() * __builtin_amdgcn_workgroup_size_y() +
             __builtin_amdgcn_workitem_id_y();
    uint width = packet->grid_size_x;

    if (dims == 2)
        return (size_t)y * (size_t)width + x;

    uint z = __builtin_amdgcn_workgroup_id_z() * __builtin_amdgcn_workgroup_size_z() +
             __builtin_amdgcn_workitem_id_z();
    uint height = packet->grid_size_y;

    return ((size_t)z * (size_t)height + (size_t)y) * (size_t)width + x;
}

// Linearised work-item ID within the work-group, with x as the
// fastest-varying dimension:
//   (item_z * group_size_y + item_y) * group_size_x + item_x
ATTR size_t
__ockl_get_local_linear_id(void)
{
    uint x = __builtin_amdgcn_workitem_id_x();
    uint y = __builtin_amdgcn_workitem_id_y();
    uint z = __builtin_amdgcn_workitem_id_z();

    // Index of this work-item's row, counted across all z-slices.
    uint row = z * __builtin_amdgcn_workgroup_size_y() + y;

    return row * __builtin_amdgcn_workgroup_size_x() + x;
}

