package com.xunyi.micro.shunt.web.flux;

import com.xunyi.micro.propagation.Propagation;
import com.xunyi.micro.propagation.context.Extractor;
import com.xunyi.micro.shunt.Shunt;
import com.xunyi.micro.shunt.propagation.CurrentShuntContext;
import com.xunyi.micro.shunt.propagation.ShuntContext;
import lombok.extern.slf4j.Slf4j;
import org.reactivestreams.Subscription;
import org.springframework.core.Ordered;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.lang.NonNull;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoOperator;
import reactor.util.context.Context;

import java.util.concurrent.atomic.AtomicBoolean;


/**
 * WebFlux {@link WebFilter} that binds a shunt scope to each request passing through the chain.
 *
 * <p>On subscription the filter either reuses a {@link CurrentShuntContext.Scope} already present
 * in the subscriber's Reactor {@link Context} (keyed by {@code CurrentShuntContext.Scope.class}),
 * or extracts a {@link ShuntContext} from the request headers with the shunt's propagation
 * {@link Extractor} and opens a new scope for it. The scope is published in the Reactor context
 * seen by the rest of the filter chain. A scope opened by this filter is closed exactly once, when
 * the chain completes, fails, or is cancelled. An inherited scope is left for its owner to close.
 */
@Slf4j
public class ShuntWebFilter implements WebFilter, Ordered {

    /** Reads propagation fields from HTTP headers, using the first value of each header. */
    public static final Propagation.Getter<HttpHeaders, String> GETTER = new Propagation.Getter<HttpHeaders, String>() {
        @Override
        public String get(HttpHeaders carrier, String key) {
            return carrier.getFirst(key);
        }
    };

    /** Opens scopes for shunt contexts extracted from incoming requests. */
    private final CurrentShuntContext currentShuntContext;

    /** Reads a {@link ShuntContext} out of incoming request headers. */
    private final Extractor<HttpHeaders> extractor;

    /**
     * Creates the filter from the given shunt's current-context holder and propagation extractor.
     *
     * @param shunt supplies {@link Shunt#currentShuntContext()} and the propagation used to build
     *     a header {@link Extractor}
     */
    public ShuntWebFilter(Shunt shunt) {
        this.currentShuntContext = shunt.currentShuntContext();
        this.extractor = shunt.propagation().extractor(GETTER);
    }

    @Override
    public @NonNull Mono<Void> filter(@NonNull ServerWebExchange exchange, @NonNull WebFilterChain chain) {
        return new MonoWebFilterShunt(chain.filter(exchange), exchange);
    }

    /** Orders this filter at {@code HIGHEST_PRECEDENCE + 4}, near the front of the filter chain. */
    @Override
    public int getOrder() {
        return Ordered.HIGHEST_PRECEDENCE + 4;
    }

    /**
     * Mono wrapping the remainder of the filter chain. It resolves a shunt scope for every
     * subscription and attaches a subscriber that manages that scope's lifetime.
     */
    private final class MonoWebFilterShunt extends MonoOperator<Void, Void> {

        private final ServerWebExchange exchange;

        MonoWebFilterShunt(Mono<? extends Void> source, ServerWebExchange exchange) {
            super(source);
            this.exchange = exchange;
        }

        @Override
        public void subscribe(@NonNull CoreSubscriber<? super Void> subscriber) {
            Context context = subscriber.currentContext();
            CurrentShuntContext.Scope inherited = findScope(context);
            // A scope inherited from the subscriber context belongs to whoever stored it there and
            // must not be closed here; only a scope opened by this filter is closed by it.
            boolean ownsScope = inherited == null;
            CurrentShuntContext.Scope scope = ownsScope ? openScope() : inherited;
            this.source.subscribe(new WebFilterShuntSubscriber(subscriber, context, scope, ownsScope));
        }

        /**
         * Looks up a scope already stored in the given Reactor context.
         *
         * @param c the subscriber's context
         * @return the stored scope, or {@code null} when the context has none
         */
        private CurrentShuntContext.Scope findScope(Context c) {
            Object key = CurrentShuntContext.Scope.class;
            if (!c.hasKey(key)) {
                return null;
            }
            return c.get(key);
        }

        /** Extracts the shunt context from the request headers and opens a new scope for it. */
        private CurrentShuntContext.Scope openScope() {
            ServerHttpRequest request = exchange.getRequest();
            ShuntContext shuntContext = ShuntWebFilter.this.extractor.extract(request.getHeaders());
            return ShuntWebFilter.this.currentShuntContext.newScope(shuntContext);
        }

        /**
         * Pass-through subscriber that exposes the scope via {@link #currentContext()}. It also
         * acts as the {@link Subscription} handed downstream, so cancellation can be observed. It
         * closes an owned scope once, on the first of complete, error, or cancel.
         */
        final class WebFilterShuntSubscriber implements CoreSubscriber<Void>, Subscription {

            private final CoreSubscriber<? super Void> actual;
            private final Context context;
            private final CurrentShuntContext.Scope scope;
            /** Whether this subscriber opened {@link #scope} and is responsible for closing it. */
            private final boolean ownsScope;
            /** Ensures the scope is closed at most once (cancel may race with termination). */
            private final AtomicBoolean terminated = new AtomicBoolean();
            private Subscription upstream;

            WebFilterShuntSubscriber(CoreSubscriber<? super Void> actual, Context context,
                    CurrentShuntContext.Scope scope, boolean ownsScope) {
                this.actual = actual;
                this.context = context.put(CurrentShuntContext.Scope.class, scope);
                this.scope = scope;
                this.ownsScope = ownsScope;
            }

            @Override
            public void onSubscribe(Subscription subscription) {
                this.upstream = subscription;
                // Hand ourselves downstream so cancel() can release the scope before propagating.
                this.actual.onSubscribe(this);
            }

            @Override
            public void onNext(Void aVoid) {
                // A Mono<Void> does not emit values in practice; forward anyway to stay transparent.
                this.actual.onNext(aVoid);
            }

            @Override
            public void onError(Throwable t) {
                try {
                    this.terminate();
                } finally {
                    // Downstream must always see the terminal signal, even if closing fails.
                    this.actual.onError(t);
                }
            }

            @Override
            public void onComplete() {
                try {
                    this.terminate();
                } finally {
                    this.actual.onComplete();
                }
            }

            @Override
            public void request(long n) {
                this.upstream.request(n);
            }

            @Override
            public void cancel() {
                // Without this, a cancelled request (e.g. client disconnect) would never close its scope.
                try {
                    this.terminate();
                } finally {
                    this.upstream.cancel();
                }
            }

            @Override
            public Context currentContext() {
                return this.context;
            }

            /** Closes the scope once, and only if this subscriber opened it. */
            private void terminate() {
                if (this.ownsScope && this.terminated.compareAndSet(false, true)) {
                    this.scope.close();
                }
            }
        }
    }
}
