package com.xunyi.beast.security.access.predicate;

import com.google.common.collect.Lists;
import com.xunyi.beast.security.access.AccessContext;
import io.netty.handler.ipfilter.IpFilterRuleType;
import io.netty.handler.ipfilter.IpSubnetFilterRule;
import lombok.Getter;
import lombok.Setter;
import org.springframework.validation.annotation.Validated;

import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

/**
 * Predicate factory that matches an {@link AccessContext} whose remote address falls inside at
 * least one of the configured IP addresses or CIDR subnets.
 *
 * <p>Each configured source is either a plain address (e.g. {@code 10.0.0.1}, {@code ::1}) or an
 * address with a CIDR prefix (e.g. {@code 192.168.0.0/16}, {@code fe80::/10}). A plain address is
 * treated as a single host: {@code /32} for IPv4 and {@code /128} for IPv6.
 */
public class RemoteAddrPredicateFactory extends AbstractPredicateFactory<RemoteAddrPredicateFactory.Config> {

    /** Prefix length that selects exactly one IPv4 host. */
    private static final int IPV4_HOST_PREFIX = 32;

    /** Prefix length that selects exactly one IPv6 host. */
    private static final int IPV6_HOST_PREFIX = 128;

    /**
     * Builds a predicate that accepts a context whose remote address matches any configured source.
     *
     * <p>All sources are parsed once, when this method is called, so a malformed entry is reported
     * here rather than on every evaluation. A context without a remote address never matches.
     *
     * @param config the predicate configuration; a {@code null} source list is treated as empty,
     *     in which case the returned predicate never matches
     * @return a predicate over {@link AccessContext}
     * @throws IllegalArgumentException if a source entry cannot be parsed (blank address,
     *     non-numeric or out-of-range prefix, or an address Netty cannot resolve)
     */
    @Override
    public Predicate<AccessContext> apply(Config config) {
        final List<IpSubnetFilterRule> rules = convert(config.sources);
        return context -> {
            final InetSocketAddress remoteAddress = context.getRemoteAddress();
            // NOTE(review): assumes getRemoteAddress() may legitimately return null (e.g. address
            // unknown); such contexts are denied.
            return remoteAddress != null && rules.stream().anyMatch(rule -> rule.matches(remoteAddress));
        };
    }

    /**
     * Parses every configured source entry into a Netty ACCEPT subnet rule.
     *
     * @param values the raw source entries; {@code null} is treated as an empty list
     * @return the parsed rules in configuration order, never {@code null}
     * @throws IllegalArgumentException if any entry cannot be parsed
     */
    @NotNull
    private List<IpSubnetFilterRule> convert(List<String> values) {
        if (values == null) {
            return new ArrayList<>();
        }
        List<IpSubnetFilterRule> rules = new ArrayList<>(values.size());
        for (String value : values) {
            addSource(rules, value);
        }
        return rules;
    }

    /**
     * Parses one source entry and appends the resulting ACCEPT rule to {@code sources}.
     *
     * <p>Surrounding whitespace is ignored. An entry without a {@code /prefix} denotes a single
     * host: IPv6 literals (recognized by containing {@code ':'}) get {@code /128}, anything else
     * gets {@code /32}.
     *
     * @param sources the list to append the parsed rule to
     * @param source the raw entry, e.g. {@code "10.0.0.1"}, {@code "10.0.0.0/8"} or {@code "::1"}
     * @throws IllegalArgumentException if the address part is blank, the prefix is not an integer,
     *     or Netty rejects the address/prefix combination
     */
    private void addSource(List<IpSubnetFilterRule> sources, String source) {
        String entry = source == null ? "" : source.trim();
        int slash = entry.indexOf('/');
        String ipAddress = (slash < 0 ? entry : entry.substring(0, slash)).trim();

        // InetAddress.getByName("") resolves to the loopback address, so an empty host would
        // silently allow loopback traffic instead of signalling a configuration error.
        if (ipAddress.isEmpty()) {
            throw new IllegalArgumentException("Remote address source has no address: '" + source + "'");
        }

        int cidrPrefix;
        if (slash < 0) {
            // A bare address is a single host. Using /32 for an IPv6 literal would match a whole
            // /32 network rather than that one host, hence the family-specific default.
            cidrPrefix = ipAddress.indexOf(':') >= 0 ? IPV6_HOST_PREFIX : IPV4_HOST_PREFIX;
        } else {
            // Everything after the first '/' must be an integer; extra '/' characters fail here.
            cidrPrefix = Integer.parseInt(entry.substring(slash + 1).trim());
        }

        sources.add(new IpSubnetFilterRule(ipAddress, cidrPrefix, IpFilterRuleType.ACCEPT));
    }

    /**
     * Configuration for {@link RemoteAddrPredicateFactory}.
     */
    @Getter
    @Setter
    @Validated
    public static class Config {

        /**
         * Allowed remote addresses: plain IPv4/IPv6 addresses or CIDR blocks such as
         * {@code 192.168.0.0/16}. Must contain at least one entry when bean validation is applied.
         */
        @NotEmpty
        private List<String> sources = Lists.newArrayList();
    }
}
