//
// Created by 乾三 on 2023/4/18.
//

#include "../include/centerfacePostprocess.h"

#include <algorithm>
#include <cmath>
#include <string>
#include <utility>
#include <vector>

namespace {
// Size, in pixels, of the model-input canvas that detections are decoded on
// before being mapped back to the original frame.
constexpr int kwidth = 640;
constexpr int kheight = 448;

// Unused legacy layout of a flat box record, kept for reference:
// enum BBoxIndex { TOPLEFTX = 0, TOPLEFTY = 1, BOTTOMRIGHTX = 2, BOTTOMRIGHTY = 3, SCORE = 4, LABEL = 5 };
}  // namespace

// Default constructor: no CenterFace-specific state needs initialising here.
centerfacePostprocessThread::centerfacePostprocessThread() = default;

/**
 * Decodes the raw CenterFace output maps into face boxes in original-frame
 * pixel coordinates, then runs NMS on them.
 *
 * @param heatmap   per-cell score map of `width * height` floats.
 * @param scale     two stacked `width * height` planes of log-sizes: plane 0
 *                  gives the box height, plane 1 the box width (exp(v) * 4).
 * @param offset    two stacked `width * height` planes: plane 0 is the y
 *                  offset, plane 1 the x offset of the centre within its cell.
 * @param landmark  landmark map; unused while landmark decoding is disabled.
 * @param width     output-map width in cells (each cell spans 4 input pixels).
 * @param height    output-map height in cells.
 * @param lms       cleared on entry; not filled while landmark decoding is disabled.
 * @param oriwidth  width of the original frame.
 * @param oriheight height of the original frame.
 * @return boxes whose score exceeds `threshold_`, after NMS. Coordinates are
 *         clamped to valid pixel indices [0, oriwidth-1] x [0, oriheight-1].
 *         Empty if a map pointer is null or a size is non-positive.
 *
 * NOTE(review): the inverse mapping assumes the kwidth x kheight model input
 * was an aspect-preserving resize of the frame with centred padding
 * (letterbox) — confirm against the preprocessing stage.
 */
std::vector<BBoxstr> centerfacePostprocessThread::outPutDecode(float* heatmap, float* scale, float* offset, float* landmark, int width, int height, std::vector<std::vector<float>>& lms,int oriwidth,int oriheight) {
    // Landmark decoding is disabled, so `landmark` is intentionally unused and
    // `lms` is only reset so callers never observe stale landmarks.
    (void)landmark;
    lms.clear();

    // A zero frame size would divide by zero below and push inf/NaN into the
    // integer conversions; null maps would be dereferenced.
    if (heatmap == nullptr || scale == nullptr || offset == nullptr ||
        width <= 0 || height <= 0 || oriwidth <= 0 || oriheight <= 0) {
        return {};
    }

    // Each output cell covers stride x stride model-input pixels.
    const float stride = 4.f;

    // The inverse letterbox transform depends only on the frame size, so it is
    // computed once rather than once per detected cell.
    const float widthScale = static_cast<float>(kwidth) / static_cast<float>(oriwidth);
    const float heightScale = static_cast<float>(kheight) / static_cast<float>(oriheight);
    float resizeScale = heightScale;
    float padX = (kwidth - heightScale * oriwidth) / 2;
    float padY = 0.f;
    if (heightScale > widthScale) {
        // Width is the limiting side: scaled by widthScale, padded vertically.
        resizeScale = widthScale;
        padX = 0.f;
        padY = (kheight - widthScale * oriheight) / 2;
    }

    // Maps a model-input coordinate back to the original frame and clamps it to
    // a valid pixel index in [0, limit - 1]. Clamping in float before the int
    // conversion keeps inf/NaN (e.g. from exp overflow) out of the cast, which
    // would otherwise be undefined behaviour.
    auto toOriginal = [resizeScale](float v, float pad, int limit) {
        const float mapped = (v - pad) / resizeScale;
        const float clamped = std::max(0.f, std::min(mapped, static_cast<float>(limit - 1)));
        return static_cast<int>(clamped);
    };

    const int len = width * height;
    std::vector<BBoxstr> detectResults;
    for (int i = 0; i < len; ++i) {
        const float score = heatmap[i];
        if (score > threshold_) {
            const int row = i / width;
            const int col = i % width;
            // Plane 0 of scale/offset is the vertical component, plane 1 the horizontal one.
            const float boxH = std::exp(scale[i]) * stride;
            const float boxW = std::exp(scale[len + i]) * stride;
            const float offY = offset[i];
            const float offX = offset[len + i];

            // Top-left corner in model-input pixels, clamped to the input canvas.
            const float x1 = std::min(std::max(0.f, (col + offX + 0.5f) * stride - boxW / 2),
                                      static_cast<float>(kwidth));
            const float y1 = std::min(std::max(0.f, (row + offY + 0.5f) * stride - boxH / 2),
                                      static_cast<float>(kheight));

            BBoxstr boundBox;
            boundBox.rect.ltX = toOriginal(x1, padX, oriwidth);
            boundBox.rect.ltY = toOriginal(y1, padY, oriheight);
            boundBox.rect.rbX = toOriginal(x1 + boxW, padX, oriwidth);
            boundBox.rect.rbY = toOriginal(y1 + boxH, padY, oriheight);
            boundBox.cls = 0;  // every detection is reported as class 0
            boundBox.score = score;
            detectResults.push_back(boundBox);
        }
    }

    return nms(modelInfo_.nmsThresh, detectResults, modelInfo_.classnum);
}



/**
 * Decodes one frame's CenterFace inference outputs and stores the resulting
 * face boxes in `objDetectDataMsg->objInfo`, replacing any previous contents.
 *
 * Each result's `detect_result` is "<label>_<score>", with the score truncated
 * (not rounded) to two decimals.
 *
 * @param objDetectDataMsg message carrying the four inference outputs
 *        (heatmap, scale, offset, landmark) and the original frame size.
 * @return ACLLITE_OK on success or for the last-frame marker; ACLLITE_ERROR if
 *         the message is null, has fewer than four outputs, or a required
 *         output buffer is null.
 */
AclLiteError centerfacePostprocessThread::InferOutputProcess(std::shared_ptr<ObjDetectDataMsg> objDetectDataMsg)
{
    if (objDetectDataMsg == nullptr) {
        ACLLITE_LOG_ERROR("detect inferoutput is null\n");
        return ACLLITE_ERROR;
    }
    // The last-frame marker carries nothing to decode.
    if (objDetectDataMsg->isLastFrame)
        return ACLLITE_OK;

    // Four outputs are indexed below; check the count first so a malformed
    // message cannot cause an out-of-bounds read.
    if (objDetectDataMsg->detectInferData.size() < 4) {
        ACLLITE_LOG_ERROR("centerface expects 4 inference outputs\n");
        return ACLLITE_ERROR;
    }

    float* heatmap = reinterpret_cast<float*>(objDetectDataMsg->detectInferData[0].data.get());
    float* scale = reinterpret_cast<float*>(objDetectDataMsg->detectInferData[1].data.get());
    float* offset = reinterpret_cast<float*>(objDetectDataMsg->detectInferData[2].data.get());
    float* landmark = reinterpret_cast<float*>(objDetectDataMsg->detectInferData[3].data.get());
    // The landmark output is not decoded, so only the first three outputs are required.
    if (heatmap == nullptr || scale == nullptr || offset == nullptr) {
        ACLLITE_LOG_ERROR("detect inferoutput is null\n");
        return ACLLITE_ERROR;
    }

    totalBox_ = modelInfo_.totalBox;
    std::vector<std::vector<float>> lms;
    // The output maps are the model input downsampled by a stride of 4.
    const std::vector<BBoxstr> bboxesNew = outPutDecode(heatmap, scale, offset, landmark,
                                                        kwidth / 4, kheight / 4, lms,
                                                        objDetectDataMsg->imageFrame.width,
                                                        objDetectDataMsg->imageFrame.height);

    auto& objInfos = objDetectDataMsg->objInfo;
    objInfos.clear();
    objInfos.reserve(bboxesNew.size());
    for (const auto& box : bboxesNew)
    {
        ObjInfo objInfo;
        objInfo.rectangle.lt.x = box.rect.ltX;
        objInfo.rectangle.lt.y = box.rect.ltY;
        objInfo.rectangle.rb.x = box.rect.rbX;
        objInfo.rectangle.rb.y = box.rect.rbY;

        // Truncate the score text to two decimals. Guard against a missing '.',
        // where npos + 3 would wrap around and cut the text to two characters.
        std::string scoreText = std::to_string(box.score);
        const std::size_t dot = scoreText.find('.');
        if (dot != std::string::npos) {
            scoreText = scoreText.substr(0, dot + 3);
        }
        objInfo.detect_result = modelInfo_.Label[box.cls] + '_' + scoreText;
        objInfos.emplace_back(std::move(objInfo));
    }
    return ACLLITE_OK;
}



// Initialisation is identical to the generic detection post-processing thread.
AclLiteError centerfacePostprocessThread::Init() {
    const AclLiteError initResult = DetectPostprocessThread::Init();
    return initResult;
}

// Message handling is delegated unchanged to the generic detection
// post-processing thread.
AclLiteError centerfacePostprocessThread::Process(int msgId, std::shared_ptr<void> data) {
    const AclLiteError processResult = DetectPostprocessThread::Process(msgId, data);
    return processResult;
}

// Forwards the config path, channel id and model description to the generic
// detection post-processing base; no extra state is initialised here.
centerfacePostprocessThread::centerfacePostprocessThread(const char *&configFile, int channelId, ModelInfo programinfo)
        : DetectPostprocessThread(configFile, channelId, programinfo) {}

// Nothing extra to release; base-class and member destructors do the cleanup.
centerfacePostprocessThread::~centerfacePostprocessThread() = default;

// Non-maximum suppression over `binfo` with IoU threshold `nmsThresh`,
// delegated to DetectPostprocessThread::nmsAllClasses.
std::vector<BBoxstr> centerfacePostprocessThread::nms(const float nmsThresh, std::vector<BBoxstr> &binfo, const uint numClasses) {
    std::vector<BBoxstr> kept = DetectPostprocessThread::nmsAllClasses(nmsThresh, binfo, numClasses);
    return kept;
}

//std::vector<BBoxstr> centerfacePostprocessThread::nonMaximumSuppression(const float nmsThresh, std::vector<BBoxstr> binfo) {
//    return DetectPostprocessThread::nonMaximumSuppression(nmsThresh, binfo);
//}

//std::vector<BBoxstr> centerfacePostprocessThread::nonMaximumSuppression(const float nmsThresh, std::vector<BBoxstr> binfo) {
//    return DetectPostprocessThread::nonMaximumSuppression(nmsThresh, binfo);
//}
