Wednesday 25 December 2013

parallel batch sorting

Had what was supposed to be a 'quick look' at converting the GA code to OpenCL from the last post. Got stuck at the first hurdle - sorting the classifier scores.

I wasn't going to be suckered into working on a 'good' solution, just one that worked. But the code I had just didn't seem fast enough compared to single-threaded Java Arrays.sort() (only 16x) so I ended up digging deeper.

For starters I was using Batcher's sort. Although it is parallel, it turns out it doesn't utilise the stream units very well in all passes, and it does particularly badly in the naive way I had extended it to sort more items than there are work items in the work-group. I think I just used Batcher's because I got it to work at some point, and IIRC I took it from Knuth.

So I looked into bitonic sort - unfortunately the "pseudocode" on the Wikipedia page is pretty much useless because it is some recursive Python - which is about as useful for demonstrating massive concurrency as a used bit of toilet paper. The pictures helped though.

So I started from first principles: I wrote down the comparison pairs in binary and worked out how to create compare-and-swap index pairs reasonably cheaply from sequential work item IDs.

For an 8-element sort, the compare-and-swap pairs are:

 pass,step 0,1
  pair  binary   flip comparison
   0-1  000 001  0
   2-3  010 011  1
   4-5  100 101  0
   6-7  110 111  1

 pass,step 1,2
  pair  binary   flip comparison
   0-2  000 010  0
   1-3  001 011  0
   4-6  100 110  1
   5-7  101 111  1
 (pass,step 1,1 uses the same pairs as pass,step 0,1)

 pass,step 2,4
  pair  binary   flip comparison
   0-4  000 100  0
   1-5  001 101  0
   2-6  010 110  0
   3-7  011 111  0
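 (pass,step 2,2 uses the same pairs as pass,step 1,2)
 (pass,step 2,1 uses the same pairs as pass,step 0,1)

One notices that one bit (bit 'pass' in the tables above, and in general the bit that matches the step size) is all 0 for one element of each comparison and all 1 for the other. The other bits simply increment naturally and are just the work item index with a bit inserted. The flip rule in these tables is always the first bit after the inserted bit; it stays at that position for the smaller steps of the same pass, which makes it bit 'pass' of the work item index.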
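
As a rough sketch of that observation, the C helpers below rebuild one row of the tables from a work item index. The names `insert_bit` and `pair_for` are made up for this illustration and are not from the original code; the flip is taken as bit 'pass' of the work item index.

```c
#include <assert.h>

/* Hypothetical helper: make room at bit position b of id and
 * place 'value' (0 or 1) there. */
static unsigned insert_bit(unsigned id, unsigned b, unsigned value)
{
    assert(value <= 1 && b < 31);
    unsigned upper = ~0u << (b + 1);   /* bits that move up by one */
    unsigned lower = (1u << b) - 1;    /* bits that stay in place  */
    return ((id << 1) & upper) | (value << b) | (id & lower);
}

/* Hypothetical helper: compare-and-swap pair and flip indicator for
 * work item 'id' in pass 'pass' at step bit 'b' (b <= pass). */
static void pair_for(unsigned id, unsigned pass, unsigned b,
                     unsigned *p0, unsigned *p1, unsigned *flip)
{
    *p0 = insert_bit(id, b, 0);        /* element with the step bit clear */
    *p1 = insert_bit(id, b, 1);        /* element with the step bit set   */
    *flip = (id >> pass) & 1;          /* 1 means use the opposite sense  */
}
```

Calling `pair_for(2, 1, 1, ...)` gives the pair 4-6 with flip 1, which is the third row of the pass,step 1,2 table.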

In short:

  lid = local work item id;       // one work item per pair: 0 .. N/2-1
  array = base address of array;  // N elements, N a power of two
  for (bit=0; (1<<bit) < N; bit++) {
    for (b=bit; b >= 0; b--) {
      // insertion masks for bits
      int upper = ~0 << (b+1);
      int lower = (1 << b)-1;
      // extract 'flip' indicator - 1 means use opposite sense
      int flip = (lid >> bit) & 1;
      // insert a 0 bit at bit position 'b' into the index
      int p0 = ((lid << 1) & upper) | (lid & lower);
      // insert a 1 bit at the same place
      int p1 = p0 | (1 << b);

      // swap when the two elements are out of order (or in order, if flipped)
      swap_if(array, p0, p1, compare(array[p0], array[p1]) ^ flip);
      // a work-group barrier is needed here before the next step
    }
  }
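
To check the index arithmetic without a GPU, the loops above can be run serially on the host. This is only a sketch under assumptions: `bitonic_sort_host` and this `swap_if` are illustrative names, the elements are plain ints sorted ascending, and the inner `lid` loop stands in for the work items that would run concurrently with a barrier after each step.

```c
#include <assert.h>

/* Swap a[p0] and a[p1] when they are out of order; 'flip' reverses the sense. */
static void swap_if(int *a, unsigned p0, unsigned p1, unsigned flip)
{
    unsigned out_of_order = a[p0] > a[p1];
    if (out_of_order ^ flip) {
        int t = a[p0];
        a[p0] = a[p1];
        a[p1] = t;
    }
}

/* Sort n ints ascending; n must be a power of two (assumption of this sketch). */
static void bitonic_sort_host(int *a, unsigned n)
{
    assert(n != 0 && (n & (n - 1)) == 0);
    for (unsigned bit = 0; (1u << bit) < n; bit++) {
        for (int b = (int)bit; b >= 0; b--) {
            unsigned upper = ~0u << (b + 1);
            unsigned lower = (1u << b) - 1;
            /* On the GPU every lid runs at once, with a barrier after each step. */
            for (unsigned lid = 0; lid < n / 2; lid++) {
                unsigned flip = (lid >> bit) & 1;
                unsigned p0 = ((lid << 1) & upper) | (lid & lower);
                unsigned p1 = p0 | (1u << b);
                swap_if(a, p0, p1, flip);
            }
        }
    }
}
```

In the last pass `lid >> bit` is always 0, so nothing is flipped and the whole array comes out ascending.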

The nice thing about this is that all accesses are sequential within the given block, and the sort sizes decrement within each sub-group. Which leads to the next little optimisation ...

Basically, once the block size (1<<b) is the same size as the local work group, the work topology changes: instead of going wide across the whole array at each step, each work-group runs across one work-group-sized block at a time. This allows one to finish all the lower stages entirely via LDS memory. Although this approach is also much more 'cache friendly' than full passes at each offset, LDS just has a much higher bandwidth than cache. Each block is accessed through the LDS via a workgroup-wide coalesced read/write, and an added bonus is that for the larger sized blocks all memory accesses are fully sequential across the whole work-group - i.e. are also fully coalesced.

It might even be a reasonable merge sort algorithm for a serial processor, although I think the address arithmetic will be too expensive in that case unless the given CPU has an insert-bit instruction that makes room at the same time.
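
The change in topology can be sketched on the host as well. Everything here is an assumption made for illustration: `WG` is a tiny pretend work-group size, the `lds` array stands in for local memory, the `memcpy` calls stand in for the coalesced block read and write, and the function names are invented. A real kernel would run the `lid` loop in parallel with a barrier between steps.

```c
#include <assert.h>
#include <string.h>

#define WG 4                   /* assumed work-group size, in work items  */
#define LDS_ELEMS (2 * WG)     /* each work item owns one pair of elements */

/* Index of the lower element of the pair: id with a 0 inserted at bit b. */
static unsigned pair_low(unsigned id, int b)
{
    unsigned upper = ~0u << (b + 1);
    unsigned lower = (1u << b) - 1;
    return ((id << 1) & upper) | (id & lower);
}

/* Compare v[p0] with its partner at distance 1<<b and swap if needed. */
static void compare_swap(int *v, unsigned p0, int b, unsigned flip)
{
    unsigned p1 = p0 | (1u << b);
    if ((unsigned)(v[p0] > v[p1]) ^ flip) {
        int t = v[p0];
        v[p0] = v[p1];
        v[p1] = t;
    }
}

static void bitonic_sort_blocked(int *a, unsigned n)
{
    assert(n >= LDS_ELEMS && (n & (n - 1)) == 0);
    int lds[LDS_ELEMS];
    for (unsigned bit = 0; (1u << bit) < n; bit++) {
        int b = (int)bit;
        /* Wide steps: the pair distance is larger than one group can hold. */
        for (; (1u << b) > WG; b--)
            for (unsigned gid = 0; gid < n / 2; gid++)
                compare_swap(a, pair_low(gid, b), b, (gid >> bit) & 1);
        /* All remaining steps stay inside aligned blocks of LDS_ELEMS. */
        for (unsigned g = 0; g < n / LDS_ELEMS; g++) {
            memcpy(lds, a + g * LDS_ELEMS, sizeof lds);   /* block read  */
            for (int s = b; s >= 0; s--)
                for (unsigned lid = 0; lid < WG; lid++)
                    compare_swap(lds, pair_low(lid, s), s,
                                 ((g * WG + lid) >> bit) & 1);
            memcpy(a + g * LDS_ELEMS, lds, sizeof lds);   /* block write */
        }
    }
}
```

The flip still comes from the global work item index (`g * WG + lid`), while the pair indices inside the block only need the local index. Global memory is touched once per block for all the low steps of a pass rather than once per step.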

On closer inspection there may be another minor optimisation for the first LDS based stage - it could perform the previous stage in two parts whilst still retaining full-width reads using a register.

I'm testing by sorting 200 arrays, each with 8192 integers (which is what I need for the task at hand). I'm using a single workgroup per array. The GPU is a Radeon HD 7970, original version.

  algorithm         time

  batchers          8.1ms
  bitonic           7.6ms
  bitonic + lds     2.7ms

  java sort        80.0ms  (1 thread, initialised with random values)

Nice!

I think? Sorting isn't something I've looked into very deeply on a GPU so the numbers might be completely mediocre. In any event it's interesting that the bitonic sort didn't really make much difference until I utilised the LDS, because although it requires fewer comparisons it still needs the same number of passes over the data. Apart from that, I couldn't see a way to utilise LDS effectively with the Batcher's sort algorithm I was using, whereas with bitonic it just falls out automagically.
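
For a sense of scale, the number of passes the bitonic loops make can be counted directly from their structure. This is a small sketch with an invented function name, and it counts only the bitonic loop nest shown earlier in the post, not the Batcher's variant.

```c
#include <assert.h>

/* Count compare-and-swap steps for n elements, n a power of two.
 * Without the LDS stage each step is one full pass over the array. */
static unsigned bitonic_steps(unsigned n)
{
    assert(n != 0 && (n & (n - 1)) == 0);
    unsigned steps = 0;
    for (unsigned bit = 0; (1u << bit) < n; bit++)
        steps += bit + 1;      /* inner loop runs b = bit .. 0 */
    return steps;
}
```

For the 8192-element arrays used in the test this gives 91 steps, each performing 4096 compare-and-swaps.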

Got a roast lamb in the oven and cracked a bottle of red, so I'm off outside to enjoy the perfect xmas weather - it's sunny, 30-something, with a very light breeze. I had enough of family this year so I'm going it alone - I'll catch up with friends another day and family in a couple of weeks.

Merry X-mas!

Update: See also a follow-up where I explore extension to non-powers of two elements.
