]> git.somenet.org - pub/jan/dst18.git/blob - ass3-event/src/test/java/dst/ass3/event/StaticQueueSink.java
Make ElasticityControllerTest more tolernat to edge cases
[pub/jan/dst18.git] / ass3-event / src / test / java / dst / ass3 / event / StaticQueueSink.java
1 package dst.ass3.event;
2
3 import java.util.ArrayList;
4 import java.util.HashMap;
5 import java.util.List;
6 import java.util.Map;
7 import java.util.concurrent.BlockingQueue;
8 import java.util.concurrent.ConcurrentHashMap;
9 import java.util.concurrent.LinkedBlockingQueue;
10 import java.util.concurrent.TimeUnit;
11
12 import org.apache.flink.streaming.api.functions.sink.SinkFunction;
13
14 /**
15  * A SinkFunction that collects objects into a queue located in a shared global state. Each collector accesses a
16  * specific key in the shared state.
17  *
18  * @param <T> the sink input type
19  */
20 public class StaticQueueSink<T> implements SinkFunction<T> {
21
22     private static final long serialVersionUID = -3965500756295835669L;
23
24     private static Map<String, BlockingQueue<?>> state = new ConcurrentHashMap<>();
25
26     private String key;
27
28     public StaticQueueSink(String key) {
29         this.key = key;
30     }
31
32     @Override
33     public void invoke(T value, Context context) throws Exception {
34         get().add(value);
35     }
36
37     public void clear() {
38         get().clear();
39     }
40
41     public List<T> take(int n) {
42         List<T> list = new ArrayList<>(n);
43
44         for (int i = 0; i < n; i++) {
45             try {
46                 list.add(get().take());
47             } catch (InterruptedException e) {
48                 throw new RuntimeException("Interrupted while accessing queue", e);
49             }
50         }
51
52         return list;
53     }
54
55     public T take() {
56         try {
57             return get().take();
58         } catch (InterruptedException e) {
59             return null;
60         }
61     }
62
63     public T poll(long ms) {
64         try {
65             return get().poll(ms, TimeUnit.MILLISECONDS);
66         } catch (InterruptedException e) {
67             return null;
68         }
69     }
70
71     public synchronized BlockingQueue<T> get() {
72         return get(key);
73     }
74
75     @SuppressWarnings("unchecked")
76     private static <R> BlockingQueue<R> get(String key) {
77         return (BlockingQueue) state.computeIfAbsent(key, k -> new LinkedBlockingQueue<>());
78     }
79
80     public static void clearAll() {
81         state.clear();
82     }
83 }