]> git.somenet.org - pub/astra/parallel.git/blob - openmp/merge/sort.c
update prefix sums
[pub/astra/parallel.git] / openmp / merge / sort.c
1 /* implementation of parallel merge sort */
2 #include <stdio.h>
3 #include <omp.h>
4 #include <stdlib.h>
5 #include <unistd.h>
6 #include "numlist.h"
7
8 #define ATYPE int
9
10 void printlist(char * message, ATYPE * ptr, int len);
11 int rank(int elem, ATYPE * list, int len);
12 /* merge(a,n,b,m,c): merges a of size n and b of size m into c */
13 void merge(int ti, ATYPE * a, int n, ATYPE * b, int m, ATYPE * c);
14
15 int main ( int argc, char ** argv) {
16         int n = LISTSIZE, m = LISTSIZE;
17         int p = 1;
18         int opt, i;
19         int a_len, b_len, b_len_end, b_len_begin;
20         ATYPE * c;
21         double startTime, endTime;
22
23         while ((opt = getopt(argc, argv, "t:")) != -1) {
24                 switch (opt) {
25                         case 't':
26                                 p = atoi(optarg);
27                                 break;
28                         default: /* '?' */
29                                 fprintf(stderr, "Usage: %s [-t threads]\n", argv[0]);
30                                 exit(EXIT_FAILURE);
31                 }
32         }
33
34         //printf ("Maximal number of threads: %i\n", omp_get_max_threads());
35         //printf ("----------------------------------\n");
36
37         c = (ATYPE *) malloc((LISTSIZE * 2 + 1) * (sizeof (ATYPE)));
38         c[LISTSIZE*2] = -1;
39
40         //printlist("0 Sorted List A:", a);
41         //printlist("0 Sorted List B:", b);
42
43         startTime = omp_get_wtime();
44
45         a_len = n/p;
46         #pragma omp parallel for shared(a,b,c,n,m,p,a_len) private(i,b_len_begin,b_len_end,b_len)
47         for (i = 0; i < p; i++) {
48                 b_len_begin = rank(a[i*a_len], b, m);
49                 b_len_end = rank(a[(i+1)*a_len], b, m);
50                 if (b_len_begin < 0) {
51                         printf ("Insert to end of list!\n");
52                         b_len_begin = n;
53                 }
54                 if (b_len_end < 0) {
55                         //printf ("Reached end of list!\n");
56                         b_len_end = n;
57                 }
58                 b_len = b_len_end - b_len_begin;
59                 //printf ("%i a_len: %i, b_len: %i (begin:%i ([%i]) -> end:%i [%i])\n", i+1, a_len, b_len, b_len_begin, b[b_len_begin], b_len_end, b[b_len_end]);
60                 merge(  i+1,
61                                 &a[i*a_len],
62                                 a_len,
63                                 &b[b_len_begin],
64                                 b_len,
65                                 &c[i*a_len+b_len_begin]);
66         }
67
68         endTime = omp_get_wtime();
69         printf("took %f seconds.\n", endTime-startTime);
70
71         //printlist("Sorted List:", c);
72         free(c);
73         return 0;
74 }
75
76 void printlist(char * message, ATYPE * ptr, int len) {
77         printf (message);
78         while (len > 0) {
79                 printf (" %i", *ptr);
80                 ptr++; len--;
81         }
82         printf ("\n");
83 }
84
85 int rank(ATYPE elem, ATYPE * list, int len) {
86         int pos;
87         int i = len;
88
89         pos = 0;
90         if (elem == -1) { return -1; }
91         while (i > 0) {
92                 /*printf ("elem_list: %i %i\n", elem, *list);*/
93                 if (elem <= *list) {
94                         return pos;
95                 }
96                 list++;
97                 pos++;
98                 i--;
99         }
100         return len;
101 }
102
103 void merge(int ti, ATYPE * a, int n, ATYPE * b, int m, int * c) {
104         int sum;
105         int i;
106         /*printf ("sorting a:%i (%i) and b:%i (%i)\n", *a, n, *b, m);*/
107         if (m<0) { m=0;}
108         if (n<0) { n=0;}
109         sum = n + m;
110         //printf ("%i modifying %i (%i+%i) (c[%i]/0x%08x -> c[%i]/0x%08x)\n", ti, sum, n, m, 0, (unsigned int) &c, sum-1, ((unsigned int) &c)+sum);
111         for (i = 0; i < sum; i++) {
112                 // n+m == 0
113                 if (n <= 0 && m <= 0) {
114                         printf ("%i Calculation failed somehow...\n", ti);
115                         return;
116                 }
117                 if (n <= 0) {
118                         c[i] = *b++; m--;
119                 } else {
120                         if (m <= 0) {
121                                 c[i] = *a++; n--;
122                         } else {
123                                 if (*a < *b) {
124                                         c[i] = *a++; n--;
125                                 } else {
126                                         c[i] = *b++; m--;
127                                 }
128                         }
129                 }
130         }
131         // printf ("merge done, n=%d, m=%d\n", n, m);
132         return;
133 }