package org.beast.cloud.grpc;

import com.google.common.collect.Maps;
import io.grpc.Channel;
import org.springframework.aop.framework.Advised;
import org.springframework.aop.support.AopUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.PropertyValues;
import org.springframework.beans.factory.annotation.Autowired;
//import org.springframework.beans.factory.config.InstantiationAwareBeanPostProcessorAdapter;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.util.ReflectionUtils;

import java.beans.PropertyDescriptor;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Injects gRPC {@link Channel}s into fields annotated with {@link GRPCClient}.
 *
 * <p>Processing happens in two phases, keyed by bean name:
 * <ol>
 *   <li>{@link #postProcessBeforeInstantiation} walks the bean class and its superclasses and
 *       records every class that declares at least one {@code @GRPCClient} field.</li>
 *   <li>{@link #postProcessPropertyValues} creates one channel per annotated field through
 *       {@code channelFactory.createChannel(annotation.value())} and writes it into the bean,
 *       unwrapping Spring AOP proxies first.</li>
 * </ol>
 * The classes are recorded up front so that injection does not depend on the runtime type of the
 * bean instance, which may have been replaced by an AOP proxy.
 *
 * <p>TODO: this class does not yet implement Spring's {@code InstantiationAwareBeanPostProcessor}
 * (the original plan was to extend {@code InstantiationAwareBeanPostProcessorAdapter}), so Spring
 * does not invoke these callbacks by itself; they only run when called explicitly.
 *
 * <p>The bookkeeping map is a plain {@code HashMap}. NOTE(review): if beans using this processor
 * can be created concurrently, consider a concurrent map.
 */
public class GRPCClientBeanPostProcessor {

    // Bean name -> classes in that bean's hierarchy that declare at least one @GRPCClient field.
    // Kept as a pending list so the fields can still be found after other AOP processing has
    // replaced the bean instance with a proxy.
    private final Map<String, List<Class<?>>> beansToProcess = Maps.newHashMap();

    // Creates the Channel injected into each @GRPCClient field.
    @Autowired
    private GRPCChannelFactory channelFactory;

    /**
     * Records which classes in the hierarchy of {@code beanClass} declare {@code @GRPCClient}
     * fields.
     *
     * <p>Intended to override {@code InstantiationAwareBeanPostProcessor#postProcessBeforeInstantiation}
     * once this class implements that interface.
     *
     * <p>Each class is recorded at most once, no matter how many annotated fields it declares, and
     * the entry for {@code beanName} is replaced (not appended to) on every call, so repeated
     * instantiation under the same name does not accumulate duplicates.
     *
     * @param beanClass the class of the bean about to be instantiated
     * @param beanName  the name of the bean
     * @return always {@code null}, so regular instantiation proceeds
     * @throws BeansException never thrown directly; declared to match the Spring callback
     */
    public Object postProcessBeforeInstantiation(Class<?> beanClass, String beanName) throws BeansException {
        List<Class<?>> annotatedClasses = new ArrayList<>();
        for (Class<?> clazz = beanClass; clazz != null; clazz = clazz.getSuperclass()) {
            if (declaresGrpcClientField(clazz)) {
                annotatedClasses.add(clazz);
            }
        }
        if (annotatedClasses.isEmpty()) {
            // Drop any stale entry recorded earlier under the same bean name.
            beansToProcess.remove(beanName);
        } else {
            beansToProcess.put(beanName, annotatedClasses);
        }
        return null;
    }

    /**
     * Injects a freshly created {@link Channel} into every {@code @GRPCClient} field of the
     * classes recorded for {@code beanName}.
     *
     * <p>Intended to override {@code InstantiationAwareBeanPostProcessor#postProcessPropertyValues}
     * once this class implements that interface. NOTE(review): newer Spring versions replace this
     * callback with {@code postProcessProperties}; confirm against the Spring version in use.
     *
     * @param pvs      the property values, returned unchanged
     * @param pds      the bean's property descriptors (unused)
     * @param bean     the bean instance, possibly an AOP proxy
     * @param beanName the name of the bean
     * @return {@code pvs}, unchanged
     * @throws IllegalStateException if unwrapping the bean, creating a channel, or writing a field
     *                               fails; the original exception is attached as the cause
     */
    public PropertyValues postProcessPropertyValues(
            PropertyValues pvs, PropertyDescriptor[] pds, Object bean, String beanName) throws BeansException {
        List<Class<?>> annotatedClasses = beansToProcess.get(beanName);
        if (annotatedClasses == null) {
            return pvs;
        }
        try {
            Object target = getTargetBean(bean);
            for (Class<?> clazz : annotatedClasses) {
                for (Field field : clazz.getDeclaredFields()) {
                    GRPCClient annotation = findGrpcClient(field);
                    if (Objects.nonNull(annotation)) {
                        Channel channel = channelFactory.createChannel(annotation.value());
                        ReflectionUtils.makeAccessible(field);
                        ReflectionUtils.setField(field, target, channel);
                    }
                }
            }
        } catch (Exception e) {
            // Fail fast: a bean left with an uninitialized channel would only break later at call time.
            throw new IllegalStateException(
                    "Failed to inject @GRPCClient channel into bean '" + beanName + "'", e);
        }
        return pvs;
    }

    /** Returns whether {@code clazz} itself (not its superclasses) declares a {@code @GRPCClient} field. */
    private static boolean declaresGrpcClientField(Class<?> clazz) {
        for (Field field : clazz.getDeclaredFields()) {
            if (findGrpcClient(field) != null) {
                return true;
            }
        }
        return false;
    }

    /**
     * Looks up {@code @GRPCClient} on {@code field}. Used both when scanning and when injecting,
     * so that both phases agree on which fields are annotated.
     *
     * @return the annotation, or {@code null} if the field is not annotated
     */
    private static GRPCClient findGrpcClient(Field field) {
        return AnnotationUtils.getAnnotation(field, GRPCClient.class);
    }

    /**
     * Unwraps Spring AOP proxies (possibly nested) until a non-proxy object is reached.
     *
     * @param bean the bean, possibly an AOP proxy
     * @return the innermost target object; {@code bean} itself if it is not a proxy
     * @throws Exception if a proxy's {@code TargetSource} fails to supply its target
     */
    private Object getTargetBean(Object bean) throws Exception {
        Object target = bean;
        while (AopUtils.isAopProxy(target)) {
            target = ((Advised) target).getTargetSource().getTarget();
        }
        return target;
    }

}
