package com.xunlei.stream.util.ip.redis;

import com.xunlei.common.Assert;
import com.xunlei.stream.util.IpUtil;
import com.xunlei.stream.util.PropertiesUtil;
import com.xunlei.stream.util.ip.IllegalIpAreaProviderException;
import com.xunlei.stream.util.ip.IpAreaApi;
import com.xunlei.stream.util.ip.IpAreaFactory;
import com.xunlei.stream.util.ip.IpAreaInfo;
import com.xunlei.stream.util.ip.mysql.MySQLIpAreaProvider;
import com.xunlei.stream.util.redis.RedisApi;
import com.xunlei.stream.util.redis.RedisFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Set;

/**
 * {@link IpAreaApi} implementation that keeps IP-range data in Redis. The ranges therefore do not
 * occupy JVM heap, and the range search is delegated to a Redis sorted set.
 * <h2>Storage</h2>
 * Each range is written with the Redis
 * <a href="https://redis.readthedocs.org/en/2.4/sorted_set.html#zadd">ZADD</a> command under
 * {@link #IP_AREA_REDIS_KEY}. The range's start IP (as a number) is the score, and the member is the
 * string {@code end->areaId} (see {@link #SPLIT}). Redis keeps members ordered by ascending score.
 * <h2>Lookup</h2>
 * The IP is converted to a number {@code i}. A
 * <a href="https://redis.readthedocs.org/en/2.4/sorted_set.html#zrevrangebyscore">ZREVRANGEBYSCORE</a>
 * query with {@code i} as the maximum score, limited to one result, returns the nearest range whose
 * start is {@code <= i}. The IP belongs to that range only if {@code i <= end}.
 * <p>
 * Loading this class runs a static initializer. It reads {@code jdbc.properties} and copies every row
 * of {@code ip_db.ip_city_cn} into Redis (see {@link #init}).
 *
 * @author xiongyingqi
 * @version 2015-12-28 16:03
 */
public class RedisIpAreaApi implements IpAreaApi {
    /** Separator between the range end and the area id inside a sorted-set member. */
    public static final  String SPLIT             = "->";
    /** Redis key of the sorted set holding all IP ranges. */
    public static final  String IP_AREA_REDIS_KEY = "ip_area_redis_key";
    private static final Logger logger            = LoggerFactory.getLogger(RedisIpAreaApi.class);

    // Loading the class populates Redis from the database configured in jdbc.properties.
    // NOTE(review): this runs on every class load (including via main), and failures inside init()
    // are only logged, not rethrown.
    static {
        PropertiesUtil propertiesUtil = new PropertiesUtil("jdbc.properties");
        String driverClassName = propertiesUtil.getProperty("jdbc.driverClassName");
        String url = propertiesUtil.getProperty("jdbc.url");
        String username = propertiesUtil.getProperty("jdbc.username");
        String password = propertiesUtil.getProperty("jdbc.password");
        init(driverClassName, url, username, password);
    }

    public RedisIpAreaApi() {

    }

    /**
     * Reads every row of {@code ip_db.ip_city_cn} (ordered by {@code start_ip}) and stores each row as
     * one sorted-set member in Redis.
     * <p>
     * Any error is logged rather than thrown. The JDBC result set, statement and connection are
     * closed whether or not loading succeeds.
     *
     * @param driverClassName JDBC driver class to load
     * @param url             JDBC connection URL
     * @param username        database user
     * @param password        database password
     */
    public static void init(String driverClassName, String url, String username, String password) {
        Connection conn = null;
        PreparedStatement pst = null;
        ResultSet ret = null;
        try {
            Class.forName(driverClassName); // make sure the JDBC driver is loaded/registered
            conn = DriverManager.getConnection(url, username, password);
            pst = conn.prepareStatement("SELECT * FROM ip_db.ip_city_cn order by start_ip");
            ret = pst.executeQuery();

            RedisApi redisApi = RedisFactory.newRedisApi("RedisIpAreaApi");
            int size = 0;
            while (ret.next()) {
                IpAreaInfo ipInfo = new IpAreaInfo();
                ipInfo.setCity(ret.getString("city"));
                ipInfo.setProvince(ret.getString("province"));
                ipInfo.setStart(ret.getLong("start_ip"));
                ipInfo.setEnd(ret.getLong("end_ip"));
                ipInfo.setKey(ret.getInt("id"));
                addKey(redisApi, ipInfo);
                size++;
            }

            logger.info("init... ipInfo size: {}", size);
        } catch (Exception e) {
            // Best effort: class loading must not fail because the ip database is unavailable.
            logger.error("init... failed to load ip area data into redis", e);
        } finally {
            closeQuietly(ret, pst, conn);
        }
    }

    /**
     * Closes the given JDBC resources in the order result set, statement, connection. Null arguments
     * are skipped, and a failure to close one resource is logged without preventing the others from
     * being closed.
     */
    private static void closeQuietly(ResultSet ret, PreparedStatement pst, Connection conn) {
        if (ret != null) {
            try {
                ret.close();
            } catch (SQLException e) {
                logger.warn("failed to close ResultSet", e);
            }
        }
        if (pst != null) {
            try {
                pst.close();
            } catch (SQLException e) {
                logger.warn("failed to close PreparedStatement", e);
            }
        }
        if (conn != null) {
            try {
                conn.close();
            } catch (SQLException e) {
                logger.warn("failed to close Connection", e);
            }
        }
    }

    /**
     * Stores one IP range in Redis. The range start is the score, and the member is
     * {@code end + SPLIT + key}.
     *
     * @param redisApi Redis client to write with
     * @param ipInfo   range to store; its start, end and key are used
     */
    private static void addKey(RedisApi redisApi, IpAreaInfo ipInfo) {
        redisApi.zadd(IP_AREA_REDIS_KEY, ipInfo.getStart(),
                ipInfo.getEnd() + SPLIT + ipInfo.getKey());
    }

    /**
     * Not supported by the Redis-backed implementation. Only area ids are stored, so use
     * {@link #ipToAreaId(String)} instead.
     *
     * @param ip IP address
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public IpAreaInfo getIpInfo(String ip) {
        throw new UnsupportedOperationException("redis的方式不支持该方法！");
    }

    /**
     * Returns the area id of the given IP.
     *
     * @param ip IP address
     * @return the area id on success, or -1 if the IP is not covered or the lookup fails
     */
    @Override
    public int ipToAreaId(String ip) {
        try {
            return findAreaId(ip);
        } catch (Exception e) {
            logger.error("ipToAreaId failed for ip: " + ip, e);
        }
        // Keep the failure value consistent with the documented contract and with findAreaId.
        return -1;
    }

    /**
     * Looks up the area id of {@code ip} in the Redis sorted set.
     * <p>
     * Each iteration asks Redis for the single member with the highest score (range start) between a
     * lower bound and the IP value. It then checks whether the IP lies within that member's range.
     * The lower bound is derived from the IP via {@code IpUtil.parseSectionAndSubSection(ip, i)} for
     * {@code i = 3}, then {@code i = 2}. NOTE(review): presumably this truncates the IP to its first
     * {@code i} sections, widening the window on the second attempt; confirm in IpUtil.
     *
     * @param ip IP address
     * @return the area id, or -1 if no range matches or an error occurs (errors are logged)
     */
    public int findAreaId(String ip) {
        try {
            // NOTE(review): a RedisApi is obtained on every lookup; confirm RedisFactory caches/pools it.
            RedisApi redisApi = RedisFactory.newRedisApi("RedisIpAreaApi");
            long ipValue = IpUtil.parseToLong(ip);
            for (int i = 3; i > 1; i--) { // at most two attempts: 3 sections, then 2 sections
                long section = IpUtil.parseSectionAndSubSection(ip, i);
                // Reverse score order, offset 0, count 1: the nearest range start at or below ipValue.
                // Assumes RedisApi takes (key, max, min, offset, count) like Jedis — TODO confirm.
                Set<String> elements = redisApi
                        .zrevrangeByScore(IP_AREA_REDIS_KEY, ipValue, section, 0, 1);
                Integer[] result = findResult(ipValue, elements);
                if (result != null && result.length > 0) {
                    return result[0];
                }
            }
        } catch (Exception e) {
            logger.error("find ip: " + ip + " with error: " + e.getMessage(), e);
        }
        return -1;
    }

    /**
     * Parses each sorted-set member and collects the area ids of the ranges that contain
     * {@code ipValue}. Only the range end is checked here; the caller's score query guarantees
     * {@code start <= ipValue}.
     *
     * @param ipValue  numeric IP value
     * @param elements sorted-set members of the form {@code end->areaId}; may be null or empty
     * @return matching area ids, or {@code null} if {@code elements} is null/empty or nothing matches
     */
    public Integer[] findResult(Long ipValue, Set<String> elements) {
        if (elements == null || elements.isEmpty()) {
            return null;
        }

        ArrayList<Integer> areaIdList = new ArrayList<Integer>();
        for (String element : elements) {
            Integer areaId = matchArea(ipValue, element);
            if (areaId != null) {
                areaIdList.add(areaId);
            }
        }

        if (areaIdList.isEmpty()) {
            return null;
        }

        return areaIdList.toArray(new Integer[0]);
    }

    /**
     * Returns the area id of {@code element} if {@code ipValue} does not exceed that range's end.
     *
     * @param ipValue numeric IP value
     * @param element sorted-set member of the form {@code end->areaId}
     * @return the area id when {@code ipValue <= end}, otherwise {@code null}
     * @throws IllegalArgumentException if {@code element} does not contain {@link #SPLIT}
     */
    private Integer matchArea(Long ipValue, String element) {
        String[] maxIpAndAreaId = parseElement(element);
        if (maxIpAndAreaId.length < 2) {
            throw new IllegalArgumentException("Malformed ip area element: " + element);
        }
        long maxIpValue = Long.parseLong(maxIpAndAreaId[0]);
        if (maxIpValue < ipValue) {
            return null; // the IP lies past the end of this range
        }
        return Integer.valueOf(maxIpAndAreaId[1]);
    }

    /**
     * Splits a sorted-set member into its parts.
     *
     * @param element sorted-set member; must contain text
     * @return the parts split on {@link #SPLIT}: index 0 is the range end, index 1 the area id
     */
    private String[] parseElement(String element) {
        Assert.hasText(element);
        return element.split(SPLIT);
    }

    /**
     * Manual smoke test that prints the looked-up area ids of two sample IPs.
     * <p>
     * NOTE(review): the api comes from {@link IpAreaFactory} with a {@link MySQLIpAreaProvider}, so
     * this presumably exercises the MySQL-backed implementation, not the Redis lookup. Running it
     * still loads this class and thus triggers the static Redis population.
     */
    public static void main(String[] args) throws IllegalIpAreaProviderException {
        IpAreaFactory.setIpAreaProvider(new MySQLIpAreaProvider());
        IpAreaApi ipAreaApi = IpAreaFactory.newApi();
        for (int j = 0; j < 1000; j++) {
            int i = ipAreaApi.ipToAreaId("221.192.168.1"); // expected: 3
            System.out.println(i);
            int i2 = ipAreaApi.ipToAreaId("211.98.181.65"); // expected: 441
            System.out.println(i2);
        }
    }

}
