package org.beast.data.querydsl;

import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Predicate;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
import org.springframework.core.convert.ConversionService;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.data.querydsl.binding.*;
import org.springframework.data.util.CastUtils;
import org.springframework.data.util.ClassTypeInformation;
import org.springframework.data.util.TypeInformation;
import org.springframework.data.web.querydsl.QuerydslPredicateArgumentResolver;
import org.springframework.lang.Nullable;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.bind.support.WebDataBinderFactory;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.method.support.ModelAndViewContainer;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Map;
import java.util.Optional;

public class CustomQuerydslPredicateArgumentResolver extends QuerydslPredicateArgumentResolver {

    private static final ResolvableType PREDICATE = ResolvableType.forClass(Predicate.class);
    private static final ResolvableType OPTIONAL_OF_PREDICATE = ResolvableType.forClassWithGenerics(Optional.class,
            PREDICATE);

    private final QuerydslBindingsFactory bindingsFactory;
    private final QuerydslPredicateBuilder predicateBuilder;

    /**
     * Creates a new {@link QuerydslPredicateArgumentResolver} using the given {@link ConversionService}.
     *
     * @param factory
     * @param conversionService defaults to {@link DefaultConversionService} if {@literal null}.
     */
    public CustomQuerydslPredicateArgumentResolver(QuerydslBindingsFactory factory,
                                                   Optional<ConversionService> conversionService) {
        super(factory, conversionService);
        this.bindingsFactory = factory;
        //改造
        this.predicateBuilder = new CustomQuerydslPredicateBuilder(conversionService.orElseGet(DefaultConversionService::new),
                factory.getEntityPathResolver());
    }


    /*
     * (non-Javadoc)
     * @see org.springframework.web.method.support.HandlerMethodArgumentResolver#resolveArgument(org.springframework.core.MethodParameter, org.springframework.web.method.support.ModelAndViewContainer, org.springframework.web.context.request.NativeWebRequest, org.springframework.web.bind.support.WebDataBinderFactory)
     */
    @Nullable
    @Override
    public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer,
                                  NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception {

        MultiValueMap<String, String> queryParameters = getQueryParameters(webRequest);


        Optional<QuerydslPredicate> annotation = Optional
                .ofNullable(parameter.getParameterAnnotation(QuerydslPredicate.class));
        TypeInformation<?> domainType = extractTypeInfo(parameter).getRequiredActualType();

        Optional<Class<? extends QuerydslBinderCustomizer<?>>> bindingsAnnotation = annotation //
                .map(QuerydslPredicate::bindings) //
                .map(CastUtils::cast);

        QuerydslBindings bindings = bindingsAnnotation //
                .map(it -> bindingsFactory.createBindingsFor(domainType, it)) //
                .orElseGet(() -> bindingsFactory.createBindingsFor(domainType));
        //TODO 改造
        Predicate result = predicateBuilder.getPredicate(domainType, queryParameters, bindings);

        if (!parameter.isOptional() && result == null) {
            return new BooleanBuilder();
        }

        return OPTIONAL_OF_PREDICATE.isAssignableFrom(ResolvableType.forMethodParameter(parameter)) //
                ? Optional.ofNullable(result) //
                : result;
    }

    private static MultiValueMap<String, String> getQueryParameters(NativeWebRequest webRequest) {

        Map<String, String[]> parameterMap = webRequest.getParameterMap();
        MultiValueMap<String, String> queryParameters = new LinkedMultiValueMap<>(parameterMap.size());

        for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
            queryParameters.put(entry.getKey(), Arrays.asList(entry.getValue()));
        }

        return queryParameters;
    }

    /**
     * Obtains the domain type information from the given method parameter. Will favor an explicitly registered on through
     * {@link QuerydslPredicate#root()} but use the actual type of the method's return type as fallback.
     *
     * @param parameter must not be {@literal null}.
     * @return
     */
    static TypeInformation<?> extractTypeInfo(MethodParameter parameter) {

        Optional<QuerydslPredicate> annotation = Optional
                .ofNullable(parameter.getParameterAnnotation(QuerydslPredicate.class));

        return annotation.filter(it -> !Object.class.equals(it.root()))//
                .<TypeInformation<?>> map(it -> ClassTypeInformation.from(it.root()))//
                .orElseGet(() -> detectDomainType(parameter));
    }

    private static TypeInformation<?> detectDomainType(MethodParameter parameter) {

        Method method = parameter.getMethod();

        if (method == null) {
            throw new IllegalArgumentException("Method parameter is not backed by a method!");
        }

        return detectDomainType(ClassTypeInformation.fromReturnTypeOf(method));
    }

    private static TypeInformation<?> detectDomainType(TypeInformation<?> source) {

        if (source.getTypeArguments().isEmpty()) {
            return source;
        }

        TypeInformation<?> actualType = source.getActualType();

        if (actualType == null) {
            throw new IllegalArgumentException(String.format("Could not determine domain type from %s!", source));
        }

        if (source != actualType) {
            return detectDomainType(actualType);
        }

        if (source instanceof Iterable) {
            return source;
        }

        return detectDomainType(source.getRequiredComponentType());
    }
}
