package org.beast.graphql.querydsl;

import org.beast.data.querydsl.CursorEncoder;
import org.beast.data.querydsl.CursorPredicateFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
import org.springframework.data.mongodb.repository.query.MongoEntityInformation;
import org.springframework.data.mongodb.repository.support.MappingMongoEntityInformation;
import org.springframework.data.mongodb.repository.support.MongoRepositoryFactory;
import org.springframework.data.mongodb.repository.support.MongoRepositoryFactoryBean;
import org.springframework.data.repository.Repository;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.RepositoryComposition;
import org.springframework.data.repository.core.support.RepositoryFactorySupport;
import org.springframework.data.repository.core.support.RepositoryFragment;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;

import java.io.Serializable;

import static org.springframework.data.querydsl.QuerydslUtils.QUERY_DSL_PRESENT;

/**
 * {@link MongoRepositoryFactoryBean} whose repository factory, in addition to the stock Spring Data
 * MongoDB fragments, appends a {@link GraphqlQuerydslMongoPredicateExecutor} fragment to every
 * repository interface assignable to {@link GraphqlQuerydslPredicateExecutor} (only when Querydsl
 * is present on the classpath).
 *
 * <p>The cursor collaborators are supplied via setter injection and are read each time a
 * repository's fragments are built.
 *
 * @param <T>  the repository interface type
 * @param <S>  the domain type managed by the repository
 * @param <ID> the identifier type of the domain type
 */
public class CustomMongoRepositoryFactoryBean<T extends Repository<S, ID>, S, ID extends Serializable> extends MongoRepositoryFactoryBean<T, S, ID> {

    @Nullable
    private CursorPredicateFactory cursorPredicateFactory;

    @Nullable
    private CursorEncoder cursorEncoder;

    /**
     * Creates a factory bean for the given repository interface.
     *
     * @param repositoryInterface the repository interface to create a proxy for
     */
    public CustomMongoRepositoryFactoryBean(Class<? extends T> repositoryInterface) {
        super(repositoryInterface);
    }

    /**
     * Injects the {@link CursorPredicateFactory} handed to each GraphQL Querydsl executor fragment.
     *
     * @param cursorPredicateFactory the predicate factory to use
     */
    @Autowired
    public void setCursorPredicateFactory(CursorPredicateFactory cursorPredicateFactory) {
        this.cursorPredicateFactory = cursorPredicateFactory;
    }

    /**
     * Legacy, misnamed setter for the {@link CursorPredicateFactory}. It is kept so existing callers
     * and property-based configuration ({@code cursorSortService}) continue to work.
     *
     * @param cursorPredicateFactory the predicate factory to use
     * @deprecated use {@link #setCursorPredicateFactory(CursorPredicateFactory)} instead
     */
    @Deprecated
    public void setCursorSortService(CursorPredicateFactory cursorPredicateFactory) {
        setCursorPredicateFactory(cursorPredicateFactory);
    }

    /**
     * Injects the {@link CursorEncoder} handed to each GraphQL Querydsl executor fragment.
     *
     * @param cursorEncoder the cursor encoder to use
     */
    @Autowired
    public void setCursorEncoder(CursorEncoder cursorEncoder) {
        this.cursorEncoder = cursorEncoder;
    }

    /**
     * Creates a {@link MongoRepositoryFactory} that, for repository interfaces assignable to
     * {@link GraphqlQuerydslPredicateExecutor}, appends a {@link GraphqlQuerydslMongoPredicateExecutor}
     * fragment to the fragments contributed by the default factory.
     *
     * @param operations the {@link MongoOperations} backing the repositories
     * @return the customized repository factory
     */
    @NonNull
    @Override
    protected RepositoryFactorySupport getFactoryInstance(@NonNull MongoOperations operations) {
        MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext =
                operations.getConverter().getMappingContext();

        return new MongoRepositoryFactory(operations) {

            /**
             * Builds {@link MappingMongoEntityInformation} for {@code domainClass}. The repository's
             * declared id type is passed along when {@code metadata} is present, otherwise {@code null}.
             */
            public <E, K> MongoEntityInformation<E, K> getEntityInformation(Class<E> domainClass,
                                                                            @Nullable RepositoryMetadata metadata) {
                @SuppressWarnings("unchecked") // the mapping context returns the entity for domainClass
                MongoPersistentEntity<E> entity =
                        (MongoPersistentEntity<E>) mappingContext.getRequiredPersistentEntity(domainClass);

                @SuppressWarnings("unchecked") // id type is declared by the repository metadata
                Class<K> idType = metadata != null ? (Class<K>) metadata.getIdType() : null;

                return new MappingMongoEntityInformation<E, K>(entity, idType);
            }

            @NonNull
            @Override
            protected RepositoryComposition.RepositoryFragments getRepositoryFragments(@NonNull RepositoryMetadata metadata) {
                RepositoryComposition.RepositoryFragments fragments = super.getRepositoryFragments(metadata);

                boolean isGraphqlQuerydslRepository = QUERY_DSL_PRESENT
                        && GraphqlQuerydslPredicateExecutor.class.isAssignableFrom(metadata.getRepositoryInterface());
                if (!isGraphqlQuerydslRepository) {
                    return fragments;
                }

                if (metadata.isReactiveRepository()) {
                    throw new InvalidDataAccessApiUsageException(
                            "Cannot combine Querydsl and reactive repository support in a single interface");
                }

                MongoEntityInformation<?, Serializable> entityInformation =
                        getEntityInformation(metadata.getDomainType(), metadata);

                // Fail fast here rather than handing null collaborators to the executor fragment.
                return fragments.append(RepositoryFragment.implemented(
                        instantiateClass(GraphqlQuerydslMongoPredicateExecutor.class,
                                entityInformation,
                                operations,
                                requireDependency(cursorPredicateFactory, "CursorPredicateFactory"),
                                requireDependency(cursorEncoder, "CursorEncoder"))));
            }
        };
    }

    /**
     * Returns {@code dependency} unchanged, or fails with a descriptive message if it was never set.
     *
     * @param dependency the injected collaborator, possibly {@code null}
     * @param name       human-readable name of the collaborator for the error message
     * @param <D>        the collaborator type
     * @return the non-null collaborator
     * @throws IllegalStateException if {@code dependency} is {@code null}
     */
    private static <D> D requireDependency(@Nullable D dependency, String name) {
        if (dependency == null) {
            throw new IllegalStateException(
                    name + " must be set before creating GraphQL Querydsl repository fragments");
        }
        return dependency;
    }
}
