package dst.ass3.messaging.impl;

import dst.ass3.messaging.*;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.amqp.rabbit.listener.MessageListenerContainer;
import org.springframework.amqp.rabbit.listener.SimpleMessageListenerContainer;
import org.springframework.context.Lifecycle;

import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

import static dst.ass3.messaging.Constants.TOPIC_EXCHANGE;
import static dst.ass3.messaging.Constants.WORK_QUEUES;
import static org.hamcrest.CoreMatchers.*;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.junit.Assert.assertThat;

public class WorkloadMonitorTest {

    private static final Logger LOG = LoggerFactory.getLogger(RequestGatewayTest.class);

    @Rule
    public RabbitResource rabbit = new RabbitResource();

    @Rule
    public Timeout timeout = new Timeout(30, TimeUnit.SECONDS);

    private IMessagingFactory factory;
    private IQueueManager queueManager;
    private IRequestGateway requestGateway;
    private IWorkloadMonitor workloadMonitor;

    @Before
    public void setUp() throws Exception {
        factory = new MessagingFactory();
        queueManager = factory.createQueueManager();
        requestGateway = factory.createRequestGateway();

        queueManager.setUp();

        workloadMonitor = factory.createWorkloadMonitor();
    }

    @After
    public void tearDown() throws Exception {
        queueManager.tearDown();

        requestGateway.close();
        queueManager.close();
        factory.close();
    }

    @Test
    public void getRequestCount_returnsCorrectCount() throws Exception {
        try {
            Map<RequestType, Long> countForType = new HashMap<>();
            for (RequestType type : RequestType.values()) {
                countForType.put(type, ThreadLocalRandom.current().nextLong(10, 20 + 1));
                for (long i = 0; i < countForType.get(type); i++) {
                    UploadRequest request = new UploadRequest("id" + i, type,
                            "urn:material:" + type.name().toLowerCase() + i);
                    LOG.info("Sending request {}", request);
                    requestGateway.uploadRequest(request);
                }
            }

            // wait for the messages to be processed by rabbit
            Thread.sleep(2000);

            assertThat(workloadMonitor.getRequestCount(), equalTo(countForType));
        } finally {
            workloadMonitor.close();
        }
    }

    @Test
    public void multipleWorkloadMonitors_uniqueQueueForEachMonitor() throws Exception {
        try (IWorkloadMonitor workloadMonitor2 = factory.createWorkloadMonitor();
             IWorkloadMonitor workloadMonitor3 = factory.createWorkloadMonitor();) {
            long nonWorkQueues = rabbit.getManager().getQueues().stream().filter(q -> !Arrays.asList(WORK_QUEUES).contains(q.getName())).count();
            assertThat(nonWorkQueues, greaterThanOrEqualTo(3L));
        } finally {
            workloadMonitor.close();
        }
    }

    @Test
    public void getAverageProcessingTime_correctAverageTime() throws Exception {
        try {
            Map<RequestType, Double> avgTimes = new HashMap<>();
            for (RequestType type : RequestType.values()) {
                long count = ThreadLocalRandom.current().nextLong(15, 25);
                long typeTime = 0;
                for (long i = 0; i < count; i++) {
                    long requestTime = ThreadLocalRandom.current().nextLong(1000, 20000 + 1);
                    if (count - i <= 10) {
                        typeTime += requestTime;
                    }

                    String body = String.format("{\"requestId\": \"%s\", \"processingTime\": \"%d\"}", UUID.randomUUID(), requestTime);
                    LOG.info("Sending request {}", body);
                    rabbit.getClient().convertAndSend(TOPIC_EXCHANGE, "requests." + type.toString().toLowerCase(), body);
                }
                avgTimes.put(type, ((double)typeTime / 10));
            }

            // wait for the messages to be processed by rabbit
            Thread.sleep(2000);

            assertThat(workloadMonitor.getAverageProcessingTime(), equalTo(avgTimes));
        } finally {
            workloadMonitor.close();
        }
    }

    @Test
    public void getWorkerCount_returnsCorrectCount() throws Exception {
        try {
            // spawn a random number of consumers
            Map<RequestType, Collection<MessageListenerContainer>> consumersForType = new HashMap<>();
            Map<RequestType, Long> consumerCountForType = new HashMap<>();
            for (RequestType type : RequestType.values()) {
                List<MessageListenerContainer> consumers = new ArrayList<>();
                consumersForType.put(type, consumers);
                consumerCountForType.put(type, ThreadLocalRandom.current().nextLong(10, 20 + 1));
                for (long i = 0; i < consumerCountForType.get(type); i++) {
                    consumers.add(spawnConsumer("dst." + type.toString().toLowerCase()));
                }
            }

            // wait for rabbit to get to know the new consumers
            Thread.sleep(2000);

            Map<RequestType, Long> workerCount = workloadMonitor.getWorkerCount();

            // stop all consumers
            consumersForType.entrySet().stream().map(Map.Entry::getValue).flatMap(Collection::stream).forEach(Lifecycle::stop);

            assertThat(workerCount, equalTo(consumerCountForType));
        } finally {
            workloadMonitor.close();
        }
    }

    private MessageListenerContainer spawnConsumer(String queue) {
        SimpleMessageListenerContainer container = new SimpleMessageListenerContainer(rabbit.getConnectionFactory());
        container.addQueueNames(queue);
        container.start();
        return container;
    }

    @Test
    public void close_removesQueues() throws Exception {
        workloadMonitor.close();

        List<org.springframework.amqp.core.Queue> queues = rabbit.getManager().getQueues();
        long nonWorkQueues = rabbit.getManager().getQueues().stream().filter(q -> !Arrays.asList(WORK_QUEUES).contains(q.getName())).count();
        assertThat(nonWorkQueues, is(0L));
    }
}