package org.beast.graphql.querydsl;

import com.google.common.collect.Iterables;
import com.querydsl.core.types.*;
import com.querydsl.core.types.dsl.PathBuilder;
import org.beast.data.querydsl.CursorEncoder;
import org.beast.data.querydsl.CursorPredicateFactory;
import org.beast.graphql.data.*;
import org.springframework.data.domain.Sort;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.repository.support.MappingMongoEntityInformation;
import org.springframework.data.mongodb.repository.support.SpringDataMongodbQuery;
import org.springframework.data.querydsl.EntityPathResolver;
import org.springframework.data.querydsl.QSort;
import org.springframework.data.querydsl.SimpleEntityPathResolver;
import org.springframework.data.repository.core.EntityInformation;
import org.springframework.data.util.ClassTypeInformation;
import org.springframework.data.util.TypeInformation;
import org.springframework.util.Assert;

import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * {@link GraphqlQuerydslPredicateExecutor} backed by MongoDB through Spring Data's
 * {@link SpringDataMongodbQuery}.
 *
 * <p>Cursor pagination is implemented by turning {@link ConnectionCursor}s into extra Querydsl
 * predicates via a {@link CursorPredicateFactory}. A cursor is encoded for every returned node
 * with a {@link CursorEncoder}.
 *
 * @param <T> the entity type
 */
public class GraphqlQuerydslMongoPredicateExecutor<T> implements GraphqlQuerydslPredicateExecutor<T> {

    private final PathBuilder<T> builder;
    private final MappingMongoEntityInformation<T, ?> entityInformation;
    private final TypeInformation<T> typeInformation;
    private final MongoOperations mongoOperations;
    /** May be {@code null} when built via the resolver-only constructor; edge creation then fails fast. */
    private final CursorEncoder cursorEncoder;
    private final CursorPredicateFactory cursorPredicateFactory;

    /**
     * Creates an executor that resolves Querydsl entity paths with
     * {@link SimpleEntityPathResolver#INSTANCE}.
     *
     * @param entityInformation      mapping metadata of the entity type
     * @param mongoOperations        operations used to run the queries
     * @param cursorPredicateFactory translates after/before cursors into predicates
     * @param cursorEncoder          encodes a cursor for each returned node
     */
    public GraphqlQuerydslMongoPredicateExecutor(
            MappingMongoEntityInformation<T, ?> entityInformation,
            MongoOperations mongoOperations,
            CursorPredicateFactory cursorPredicateFactory,
            CursorEncoder cursorEncoder
    ) {
        this(entityInformation, mongoOperations, SimpleEntityPathResolver.INSTANCE, cursorPredicateFactory,
                cursorEncoder);
    }

    /**
     * Creates an executor without a {@link CursorEncoder}.
     *
     * <p>Any operation that has to build an edge (and therefore encode a cursor) fails with an
     * {@link IllegalStateException}. Prefer the constructor that also accepts a {@link CursorEncoder}.
     *
     * @param entityInformation      mapping metadata of the entity type
     * @param mongoOperations        operations used to run the queries
     * @param resolver               resolves the Querydsl entity path used for sorting
     * @param cursorPredicateFactory translates after/before cursors into predicates
     */
    public GraphqlQuerydslMongoPredicateExecutor(
            MappingMongoEntityInformation<T, ?> entityInformation,
            MongoOperations mongoOperations,
            EntityPathResolver resolver,
            CursorPredicateFactory cursorPredicateFactory) {
        this(entityInformation, mongoOperations, resolver, cursorPredicateFactory, null);
    }

    /**
     * Creates a fully configured executor.
     *
     * @param entityInformation      mapping metadata of the entity type
     * @param mongoOperations        operations used to run the queries
     * @param resolver               resolves the Querydsl entity path used for sorting
     * @param cursorPredicateFactory translates after/before cursors into predicates
     * @param cursorEncoder          encodes a cursor for each returned node
     */
    public GraphqlQuerydslMongoPredicateExecutor(
            MappingMongoEntityInformation<T, ?> entityInformation,
            MongoOperations mongoOperations,
            EntityPathResolver resolver,
            CursorPredicateFactory cursorPredicateFactory,
            CursorEncoder cursorEncoder) {
        this.entityInformation = entityInformation;
        this.cursorPredicateFactory = cursorPredicateFactory;
        this.builder = pathBuilderFor(resolver.createPath(entityInformation.getJavaType()));
        this.mongoOperations = mongoOperations;
        this.typeInformation = ClassTypeInformation.from(this.entityInformation.getJavaType());
        this.cursorEncoder = cursorEncoder;
    }

    /**
     * Wraps an entity path in a {@link PathBuilder}, so sort properties can be resolved by name.
     */
    protected static <E> PathBuilder<E> pathBuilderFor(EntityPath<E> path) {
        return new PathBuilder<>(path.getType(), path.getMetadata());
    }

    /**
     * Returns the nodes of the page selected by {@code pageable}, without their cursors.
     */
    @Override
    public Iterable<T> findAll(Predicate predicate, Pageable pageable) {
        return findAllEdge(predicate, pageable).stream()
                .map(Edge::getNode)
                .collect(Collectors.toList());
    }

    /**
     * Returns the page selected by {@code pageable} as edges, each carrying an encoded cursor.
     *
     * @throws IllegalArgumentException if {@code predicate} or {@code pageable} is {@code null}
     * @throws IllegalStateException    if no {@link CursorEncoder} is configured
     */
    @Override
    public List<Edge<T>> findAllEdge(Predicate predicate, Pageable pageable) {
        Assert.notNull(predicate, "Predicate must not be null!");
        Assert.notNull(pageable, "Pageable must not be null!");

        Sort sort = pageable.getSort();
        return applyPagination(createQueryFor(predicate), pageable).fetch().stream()
                .map(node -> toEdge(node, sort))
                .collect(Collectors.toList());
    }

    /**
     * Returns the edge immediately following {@code after} in {@code sort} order.
     *
     * @return the next edge, or {@code null} if there is none
     */
    @Override
    public Edge<T> findNextEdge(Predicate predicate, ConnectionCursor after, Sort sort) {
        SpringDataMongodbQuery<T> query = applyAfter(createQueryFor(predicate), after, sort);
        // Without an explicit ordering Mongo may return any matching document, not the adjacent one.
        T node = applySorting(query, sort, false).limit(1).fetchFirst();
        return node == null ? null : toEdge(node, sort);
    }

    /**
     * Returns the edge immediately preceding {@code before} in {@code sort} order.
     *
     * @return the previous edge, or {@code null} if there is none
     */
    @Override
    public Edge<T> findPreviousEdge(Predicate predicate, ConnectionCursor before, Sort sort) {
        SpringDataMongodbQuery<T> query = applyBefore(createQueryFor(predicate), before, sort);
        // The nearest preceding node is the first one when walking the sort backwards.
        T node = applySorting(query, sort, true).limit(1).fetchFirst();
        return node == null ? null : toEdge(node, sort);
    }

    /**
     * Builds a connection for the requested page.
     *
     * <p>The connection contains:
     * <ul>
     *   <li>the page's edges;</li>
     *   <li>page info with start/end cursors and previous/next-page flags;</li>
     *   <li>the total number of documents matching {@code predicate}, ignoring cursors and limits.</li>
     * </ul>
     */
    @Override
    public DefaultConnection<T> connect(Predicate predicate, Pageable pageable) {
        // findAllEdge validates the arguments before pageable is dereferenced below.
        List<Edge<T>> edges = findAllEdge(predicate, pageable);
        Sort sort = pageable.getSort();
        Edge<T> first = Iterables.getFirst(edges, null);
        Edge<T> last = Iterables.getLast(edges, null);

        boolean hasPreviousPage = first != null && findPreviousEdge(predicate, first.getCursor(), sort) != null;
        boolean hasNextPage = last != null && findNextEdge(predicate, last.getCursor(), sort) != null;
        SpringDataMongodbQuery<T> countQuery = createQueryFor(predicate);
        return new DefaultConnection<>(edges, new PageInfo(
                first == null ? null : first.getCursor(),
                last == null ? null : last.getCursor(),
                hasPreviousPage,
                hasNextPage
        ), countQuery.fetchCount());
    }

    /**
     * Returns the entity metadata this executor was created with.
     */
    protected EntityInformation<T, ?> typeInformation() {
        return entityInformation;
    }

    /** Creates a query over the entity collection filtered by {@code predicate}. */
    private SpringDataMongodbQuery<T> createQueryFor(Predicate predicate) {
        return createQuery().where(predicate);
    }

    /** Creates an unfiltered query over the entity collection. */
    private SpringDataMongodbQuery<T> createQuery() {
        return new SpringDataMongodbQuery<T>(mongoOperations, typeInformation().getJavaType());
    }

    /** Restricts {@code query} to nodes after {@code cursor}, if the factory yields a predicate. */
    private SpringDataMongodbQuery<T> applyAfter(SpringDataMongodbQuery<T> query, ConnectionCursor cursor, Sort sort) {
        Predicate predicate = cursorPredicateFactory.after(typeInformation, cursor, sort);
        return predicate == null ? query : query.where(predicate);
    }

    /** Restricts {@code query} to nodes before {@code cursor}, if the factory yields a predicate. */
    private SpringDataMongodbQuery<T> applyBefore(SpringDataMongodbQuery<T> query, ConnectionCursor cursor, Sort sort) {
        Predicate predicate = cursorPredicateFactory.before(typeInformation, cursor, sort);
        return predicate == null ? query : query.where(predicate);
    }

    /**
     * Applies cursors, offset/limit and sorting from {@code pageable}.
     *
     * <p>When {@code last} is present, it takes precedence over {@code first}. The window is then
     * positioned at the tail of the cursor-restricted result, shifted by {@code offset}.
     */
    private SpringDataMongodbQuery<T> applyPagination(SpringDataMongodbQuery<T> query, Pageable pageable) {
        Sort sort = pageable.getSort();

        Optional<ConnectionCursor> afterOptional = pageable.getAfterOptional();
        if (afterOptional.isPresent()) {
            query = applyAfter(query, afterOptional.get(), sort);
        }
        Optional<ConnectionCursor> beforeOptional = pageable.getBeforeOptional();
        if (beforeOptional.isPresent()) {
            query = applyBefore(query, beforeOptional.get(), sort);
        }

        Optional<Integer> offsetOptional = pageable.getOffsetOptional();
        Optional<Integer> lastOptional = pageable.getLastOptional();
        if (lastOptional.isPresent()) {
            long count = query.fetchCount();
            int last = lastOptional.get();
            int offset = offsetOptional.orElse(0);
            // Querydsl rejects negative offsets, which would occur whenever fewer than `last` documents match.
            query = query.offset(Math.max(0L, count - last + offset)).limit(last);
        } else {
            offsetOptional.ifPresent(query::offset);
            pageable.getFirstOptional().ifPresent(query::limit);
        }
        return applySorting(query, sort, false);
    }

    /**
     * Adds the orderings of {@code sort} to {@code query}.
     *
     * <p>A {@code null} sort leaves the query unordered. When {@code reversed} is true, every
     * ordering direction is flipped.
     */
    private SpringDataMongodbQuery<T> applySorting(SpringDataMongodbQuery<T> query, Sort sort, boolean reversed) {
        if (sort == null) {
            return query;
        }
        for (OrderSpecifier<?> specifier : toOrderSpecifiers(sort)) {
            query.orderBy(reversed ? reverse(specifier) : specifier);
        }
        return query;
    }

    /** Returns a copy of {@code specifier} with its direction (and explicit null placement) flipped. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static OrderSpecifier<?> reverse(OrderSpecifier<?> specifier) {
        Order order = specifier.isAscending() ? Order.DESC : Order.ASC;
        OrderSpecifier.NullHandling nullHandling = specifier.getNullHandling();
        if (nullHandling == OrderSpecifier.NullHandling.NullsFirst) {
            nullHandling = OrderSpecifier.NullHandling.NullsLast;
        } else if (nullHandling == OrderSpecifier.NullHandling.NullsLast) {
            nullHandling = OrderSpecifier.NullHandling.NullsFirst;
        }
        return new OrderSpecifier(order, specifier.getTarget(), nullHandling);
    }

    /** Wraps {@code node} in an edge whose cursor is encoded for {@code sort}. */
    private Edge<T> toEdge(T node, Sort sort) {
        Assert.state(cursorEncoder != null,
                "No CursorEncoder configured; use a constructor that accepts a CursorEncoder");
        ConnectionCursor cursor = new ConnectionCursor(cursorEncoder.encode(typeInformation, sort, node));
        return new DefaultEdge<>(node, cursor);
    }

    /**
     * Converts a Spring {@link Sort} into Querydsl order specifiers.
     *
     * <p>A {@link QSort} already carries specifiers; any other sort is resolved property-by-property.
     */
    protected List<OrderSpecifier<?>> toOrderSpecifiers(Sort sort) {
        if (sort instanceof QSort) {
            return ((QSort) sort).getOrderSpecifiers();
        }
        return sort.stream().map(this::toOrder).collect(Collectors.toList());
    }

    /**
     * Converts a single Spring {@link Sort.Order} into a Querydsl {@link OrderSpecifier} on the entity path.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    protected OrderSpecifier<?> toOrder(Sort.Order order) {
        // OrderSpecifier requires a Comparable target; the untyped path forces a raw construction.
        Expression<Object> property = builder.get(order.getProperty());
        return new OrderSpecifier(order.isAscending() ? Order.ASC : Order.DESC, property);
    }
}

