package com.xunlei.stream.util.ip.mysql;

import com.xunlei.stream.util.PropertiesUtil;
import com.xunlei.stream.util.ip.IpAreaApi;
import com.xunlei.stream.util.ip.IpAreaInfo;
import com.xunlei.common.Assert;
import com.xunlei.common.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.ArrayList;

/**
 * {@link IpAreaApi} implementation backed by an in-memory table of IP ranges that is
 * loaded from the {@code ip_db.ip_city_cn} MySQL table.
 *
 * <p>The table is loaded once from the class's static initializer using the connection
 * settings found in {@code jdbc.properties}; lookups are then answered with a binary
 * search over {@link #vecIpInfo}, which is ordered by {@code start_ip}.
 *
 * <p>No synchronization is performed in this class: calling {@link #init} while lookups
 * are running on other threads is not guarded.
 *
 * @author <a href="http://xiongyingqi.com">qi</a>
 * @version 2015-11-02 20:48
 */
public class MySQLIpAreaApi implements IpAreaApi {
    private static final Logger                logger    = LoggerFactory
            .getLogger(MySQLIpAreaApi.class);

    /** IP ranges ordered by ascending {@code start_ip}; filled by {@link #init}. */
    public static final  ArrayList<IpAreaInfo> vecIpInfo = new ArrayList<IpAreaInfo>();

    /** Query used to load the range table; the ORDER BY is what makes binary search valid. */
    private static final String LOAD_SQL = "SELECT * FROM ip_db.ip_city_cn order by start_ip";

    static {
        PropertiesUtil propertiesUtil = new PropertiesUtil("jdbc.properties");
        String driverClassName = propertiesUtil.getProperty("jdbc.driverClassName");
        String url = propertiesUtil.getProperty("jdbc.url");
        String username = propertiesUtil.getProperty("jdbc.username");
        String password = propertiesUtil.getProperty("jdbc.password");
        init(driverClassName, url, username, password);
    }

    public MySQLIpAreaApi() {

    }

    /**
     * Loads the IP range table from the database and replaces the contents of
     * {@link #vecIpInfo} with it.
     *
     * <p>Rows are first collected into a local list and only swapped into
     * {@link #vecIpInfo} after the whole result set has been read, so a failure halfway
     * through leaves the previous contents untouched and a repeated call does not
     * append duplicate rows (which would break the ordering the lookup relies on).
     * Any failure is logged and swallowed; it is never propagated to the caller.
     *
     * @param driverClassName fully qualified JDBC driver class name
     * @param url             JDBC connection URL
     * @param username        database user
     * @param password        database password
     */
    public static void init(String driverClassName, String url, String username, String password) {
        try {
            // Make sure the JDBC driver class is loaded and registered.
            Class.forName(driverClassName);

            ArrayList<IpAreaInfo> loaded = new ArrayList<IpAreaInfo>();

            // Nested try/finally so that every JDBC resource is released even when the
            // query or the row mapping throws.
            Connection conn = DriverManager.getConnection(url, username, password);
            try {
                PreparedStatement pst = conn.prepareStatement(LOAD_SQL);
                try {
                    ResultSet ret = pst.executeQuery();
                    try {
                        while (ret.next()) {
                            IpAreaInfo ipInfo = new IpAreaInfo();
                            ipInfo.setCity(ret.getString("city"));
                            ipInfo.setProvince(ret.getString("province"));
                            ipInfo.setStart(ret.getLong("start_ip"));
                            ipInfo.setEnd(ret.getLong("end_ip"));
                            ipInfo.setKey(ret.getInt("id"));
                            loaded.add(ipInfo);
                        }
                    } finally {
                        ret.close();
                    }
                } finally {
                    pst.close();
                }
            } finally {
                conn.close();
            }

            // Publish only a completely loaded table.
            vecIpInfo.clear();
            vecIpInfo.addAll(loaded);
            logger.info("initialization ipInfo size: {}", loaded.size());
        } catch (Exception e) {
            logger.error("init... failed to load ip area info from database", e);
        }
    }

    /**
     * Converts a dotted-quad IPv4 string into its numeric value
     * ({@code a * 2^24 + b * 2^16 + c * 2^8 + d}).
     *
     * @param ip textual IPv4 address, e.g. {@code "192.168.0.1"}; surrounding whitespace
     *           is ignored
     * @return the address as a non-negative {@code long}
     * @throws Exception if the text is not a valid address: a section that is not an
     *                   integer raises {@link NumberFormatException}, a section outside
     *                   0..255 raises {@link IllegalArgumentException}; blank input or a
     *                   wrong section count is rejected by the project's {@code Assert}
     *                   helper (exact exception type depends on that helper)
     */
    public static long ipToLong(String ip) throws Exception {
        Assert.hasText(ip, "ip must no be null!");
        // Limit -1 keeps trailing empty sections, so "1.2.3.4." is rejected instead of
        // being silently treated as "1.2.3.4".
        String[] sections = ip.trim().split("\\.", -1);
        Assert.notEmpty(sections, "illegal ip format!");
        Assert.equals(sections.length, 4, "illegal ip format!");
        long rs = 0;
        for (String section : sections) {
            int octet = Integer.parseInt(section);
            if (octet < 0 || octet > 255) {
                throw new IllegalArgumentException("illegal ip format! ip: " + ip);
            }
            rs = (rs << 8) + octet;
        }
        return rs;
    }


    /**
     * Looks up the area information for an IP address.
     *
     * @param ip IPv4 address in dotted-quad form
     * @return the matching {@link IpAreaInfo}, or {@code null} when the input is blank,
     *         cannot be parsed, the table is empty, or no loaded range contains the address
     */
    @Override
    public IpAreaInfo getIpInfo(String ip) {
        if (!StringUtils.hasText(ip)) {
            return null;
        }
        final long target;
        try {
            target = ipToLong(ip);
        } catch (Exception e) {
            logger.error("getIpInfo... failed to parse ip: " + ip, e);
            return null;
        }

        // Nothing was loaded (e.g. init failed): there is nothing to search, and the
        // boundary accesses below would otherwise throw.
        if (vecIpInfo.isEmpty()) {
            return null;
        }

        int size = vecIpInfo.size() - 1;
        int begin = 0;
        int end = size;

        if (logger.isDebugEnabled()) {
            logger.debug(
                    "begin = {}, end = {}, size = {}, vecIpInfo.get(begin).start = {}, vecIpInfo.get(end).getEnd() = {},  vecIpInfo.get(begin).getKey() = {}, iIp = {}",
                    begin, end, size, vecIpInfo.get(begin).getStart(),
                    vecIpInfo.get(end).getEnd(), vecIpInfo.get(begin).getKey(), target);
        }
        // Quick rejection: the address lies outside the whole covered span.
        if (vecIpInfo.get(begin).getStart() > target || vecIpInfo.get(end).getEnd() < target) {
            return null;
        }

        // Binary search on the range start. On exit without a hit, 'end' is the last
        // entry whose start is below the address and 'begin' is the entry after it.
        while (begin <= end) {
            int midd = (begin + end) >>> 1;

            IpAreaInfo midIpAreaInfo = vecIpInfo.get(midd);
            if (midIpAreaInfo == null) {
                logger.error("getIpInfo... IpAreaInfo is null when get mid: {}", midd);
                return null;
            }

            if (target == midIpAreaInfo.getStart() || target == midIpAreaInfo.getEnd()) {
                return midIpAreaInfo;
            } else if (target > midIpAreaInfo.getStart()) {
                begin = midd + 1;
            } else {
                end = midd - 1;
            }
        }

        // 'begin' may now equal vecIpInfo.size() (address inside the last range) and
        // 'end' may be -1, so each candidate index is bounds-checked before it is read.
        if (end >= 0 && covers(vecIpInfo.get(end), target)) {
            return vecIpInfo.get(end);
        }
        if (begin <= size && covers(vecIpInfo.get(begin), target)) {
            return vecIpInfo.get(begin);
        }
        return null;
    }

    /** Returns whether {@code info} is non-null and its [start, end] range contains {@code ip}. */
    private static boolean covers(IpAreaInfo info, long ip) {
        return info != null && info.getStart() <= ip && info.getEnd() >= ip;
    }

    /**
     * Looks up the area id for an IP address.
     *
     * @param ip IPv4 address in dotted-quad form
     * @return the key of the matching area, or {@code -1} when the lookup fails
     */
    @Override
    public int ipToAreaId(String ip) {
        IpAreaInfo ipInfo = getIpInfo(ip);
        if (ipInfo == null) {
            return -1;
        }
        return ipInfo.getKey();
    }
}
