package com.censoft.flink.mqtt;

import com.alibaba.fastjson2.JSON;
import com.censoft.flink.domain.AlgorithmPushDto;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.eclipse.paho.client.mqttv3.*;
import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence;

import java.io.ByteArrayOutputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;


/**
 * Flink parallel source that subscribes to MQTT topics and emits every message, parsed from JSON
 * into an {@link AlgorithmPushDto}, into the stream.
 *
 * <p>Each parallel subtask opens its own MQTT connection (with a random client id) and subscribes
 * to the same topics. NOTE(review): with parallelism greater than one every subtask therefore
 * receives its own copy of each message; confirm the source runs with parallelism 1, or switch to
 * broker-side shared subscriptions.
 *
 * <p>Messages are handed from the Paho callback thread to the Flink source thread through a small
 * bounded queue. When the queue is full the callback waits, which throttles consumption.
 *
 * @author zhongyulin
 */
public class MqttConsumer extends RichParallelSourceFunction<AlgorithmPushDto> {

    /** Capacity of the hand-off queue between the MQTT callback thread and {@link #run}. */
    private static final int QUEUE_CAPACITY = 10;

    /** How long timed queue operations wait before re-checking {@link #running}. */
    private static final long POLL_INTERVAL_MS = 100L;

    /** Pause between reconnect attempts after the connection is lost. */
    private static final long RECONNECT_DELAY_MS = 1000L;

    // Per-subtask MQTT client. It is an instance field (the original was static) so that parallel
    // subtasks sharing a JVM do not overwrite each other's connection. It is transient because
    // MqttClient is not serializable; it is created in run() on the task manager.
    private transient MqttClient client;
    // Topic handle; this class never assigns it itself, only setMqttTopic does.
    private transient MqttTopic mqttTopic;
    // Bounded hand-off queue: filled by MsgCallback.messageArrived, drained by run().
    private final BlockingQueue<AlgorithmPushDto> queue = new ArrayBlockingQueue<>(QUEUE_CAPACITY);
    // Topic setting passed through to MqttConfig when connecting.
    private String msgTopic;
    // Cleared by cancel()/close(); stops run(), the reconnect loop and a waiting messageArrived.
    private volatile boolean running = true;

    /**
     * Creates the MQTT client, registers the message callback, connects and subscribes to every
     * configured topic with QoS 0.
     *
     * @throws MqttException if the client cannot be created, connected or subscribed
     */
    private void connect() throws MqttException {
        // NOTE(review): broker URL and credentials are hard-coded here.
        MqttConfig mqttConfigBean = new MqttConfig("", "", "tcp://127.0.0.1:1883",
                "DC" + (int) (Math.random() * 100000000), msgTopic);
        try {
            client = new MqttClient(mqttConfigBean.getHostUrl(), mqttConfigBean.getClientId(),
                    new MemoryPersistence());

            MqttConnectOptions options = new MqttConnectOptions();
            // cleanSession=false: ask the broker to keep this client's session across reconnects.
            options.setCleanSession(false);
            options.setUserName(mqttConfigBean.getUsername());
            options.setPassword(mqttConfigBean.getPassword().toCharArray());
            // Connection timeout, in seconds.
            options.setConnectionTimeout(30);
            // Keep-alive (heartbeat) interval, in seconds.
            options.setKeepAliveInterval(20);

            String[] topics = mqttConfigBean.getMsgTopic();
            // New int arrays are zero-filled, i.e. QoS 0 for every topic.
            int[] qos = new int[topics.length];

            client.setCallback(new MsgCallback(client, options, topics, qos));
            client.connect(options);
            client.subscribe(topics, qos);
            System.out.println("MQTT连接成功:" + mqttConfigBean.getClientId() + ":" + client);
        } catch (MqttException | RuntimeException e) {
            // Propagate instead of swallowing: otherwise run() would block forever on an empty
            // queue while the job looks healthy. run() closes the partially built client.
            System.out.println("MQTT连接异常：" + e);
            throw e;
        }
    }

    /**
     * Paho callback that parses incoming messages into {@link #queue} and reconnects after an
     * unexpected connection loss. It is a non-static inner class so it can reach the enclosing
     * source's queue and running flag.
     */
    class MsgCallback implements MqttCallback {
        private MqttClient client;
        private MqttConnectOptions options;
        private String[] topic;
        private int[] qos;

        public MsgCallback() {
        }

        public MsgCallback(MqttClient client, MqttConnectOptions options, String[] topic, int[] qos) {
            this.client = client;
            this.options = options;
            this.topic = topic;
            this.qos = qos;
        }

        /**
         * Called when the connection drops. Retries connect + subscribe once per
         * {@link #RECONNECT_DELAY_MS} until it succeeds, the source stops running, or this thread
         * is interrupted.
         */
        @Override
        public void connectionLost(Throwable throwable) {
            throwable.printStackTrace();
            System.out.println("MQTT连接断开，发起重连");
            while (running) {
                try {
                    Thread.sleep(RECONNECT_DELAY_MS);
                    client.connect(options);
                    // Re-subscribe to the same topics.
                    client.subscribe(topic, qos);
                    System.out.println("MQTT重新连接成功:" + client);
                    return;
                } catch (InterruptedException e) {
                    // Preserve the interrupt and give up reconnecting.
                    Thread.currentThread().interrupt();
                    return;
                } catch (Exception e) {
                    // Best effort: log and retry after the next delay.
                    e.printStackTrace();
                }
            }
        }

        /**
         * Decodes the payload as UTF-8 JSON and enqueues the resulting {@link AlgorithmPushDto}.
         * Unparseable payloads are skipped. If the queue is full, waits for space while the
         * source is still running.
         */
        @Override
        public void messageArrived(String s, MqttMessage message) throws Exception {
            String msg = new String(message.getPayload(), StandardCharsets.UTF_8);
            AlgorithmPushDto algorithmPushDto;
            try {
                algorithmPushDto = JSON.parseObject(msg, AlgorithmPushDto.class);
            } catch (RuntimeException e) {
                // An exception escaping messageArrived makes Paho shut the client down, so one bad
                // payload would cost the connection. Skip it instead.
                System.out.println("MQTT message on topic " + s + " is not valid JSON, skipped: " + e);
                return;
            }
            if (algorithmPushDto == null) {
                return;
            }
            // Timed offer instead of put(): once the source is cancelled nothing drains the queue,
            // and an unbounded put() would park the Paho callback thread forever.
            while (running) {
                if (queue.offer(algorithmPushDto, POLL_INTERVAL_MS, TimeUnit.MILLISECONDS)) {
                    return;
                }
            }
        }

        /**
         * Serializes {@code obj} with Java serialization.
         *
         * @return the serialized bytes, or {@code null} when {@code obj} is {@code null}
         */
        public byte[] getBytesFromObject(Serializable obj) throws Exception {
            if (obj == null) {
                return null;
            }
            ByteArrayOutputStream bo = new ByteArrayOutputStream();
            try (ObjectOutputStream oo = new ObjectOutputStream(bo)) {
                oo.writeObject(obj);
            }
            return bo.toByteArray();
        }

        @Override
        public void deliveryComplete(IMqttDeliveryToken iMqttDeliveryToken) {
            // Not used: this class only subscribes and never publishes.
        }
    }

    /**
     * Connects to the broker, then emits queued messages until {@link #cancel()} is called. The
     * MQTT client is always disconnected and closed on exit.
     */
    @Override
    public void run(final SourceContext<AlgorithmPushDto> ctx) throws Exception {
        try {
            connect();
            while (running) {
                // Timed poll: blocks without spinning, yet notices cancel() within one interval.
                AlgorithmPushDto next = queue.poll(POLL_INTERVAL_MS, TimeUnit.MILLISECONDS);
                if (next != null) {
                    synchronized (ctx.getCheckpointLock()) {
                        ctx.collect(next);
                    }
                }
            }
        } finally {
            running = false;
            closeClientQuietly();
        }
    }

    /** Signals {@link #run} (and the callback's loops) to stop. */
    @Override
    public void cancel() {
        running = false;
    }

    /** Stops the source and releases the MQTT connection if it is still open. */
    @Override
    public void close() throws Exception {
        running = false;
        try {
            closeClientQuietly();
        } finally {
            super.close();
        }
    }

    /** Disconnects (if connected) and closes the MQTT client, logging rather than throwing. */
    private void closeClientQuietly() {
        MqttClient current = client;
        if (current == null) {
            return;
        }
        client = null;
        try {
            if (current.isConnected()) {
                current.disconnect();
            }
        } catch (MqttException e) {
            e.printStackTrace();
        }
        try {
            current.close();
        } catch (MqttException e) {
            e.printStackTrace();
        }
    }

    /**
     * Subscribes the current client to a single topic. Subscription errors are printed, not
     * thrown.
     *
     * @param topic topic filter to subscribe to
     * @param qos   requested QoS level
     */
    public void subscribe(String topic, int qos) {
        try {
            System.out.println("topic:" + topic);
            client.subscribe(topic, qos);
        } catch (MqttException e) {
            e.printStackTrace();
        }
    }

    /** Returns this subtask's MQTT client, or {@code null} before connecting or after closing. */
    public MqttClient getClient() {
        return client;
    }

    public void setClient(MqttClient client) {
        this.client = client;
    }

    /** Returns the topic handle previously stored via {@link #setMqttTopic}. */
    public MqttTopic getMqttTopic() {
        return mqttTopic;
    }

    public void setMqttTopic(MqttTopic mqttTopic) {
        this.mqttTopic = mqttTopic;
    }

    public MqttConsumer() {
    }

    /**
     * @param msgTopic topic setting passed to {@code MqttConfig} when connecting
     */
    public MqttConsumer(String msgTopic) {
        this.msgTopic = msgTopic;
    }
}