Monday, 9 May 2016

java, fft, and jvm

After quite a break i finally got stuck into a bit of hacking over the weekend. Actually i literally spent every waking hour of the weekend tinkering on the keyboard, about 25 hours or so. Yes my arse is a bit sore.

I will probably end up writing more about it but for now I'm just going to summarise some findings and results of some tinkering with FFT routines in Java. It may sound oxymoronic but i've been a bit under the weather so this was something "easy" to play with which doesn't require too much hard thought.

Initially i was working on an out-of-place in-order implementation but I hit some problems with 2D transforms so decided to settle on in-place out-of-order instead. Using a decimation in frequency (DIF) for the forward transform and a decimation in time (DIT) for the inverse allows the same routine to be used for all passes with no bit-reversal needed. I only implemented the basic Cooley-Tukey algorithm and only with radix-2 or radix-4 kernels and with special case code for the first (DIT) or last (DIF) pass. I'm only really interested in "video-image-sized" routines at this point.
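To make the DIF/DIT pairing concrete, here is a minimal radix-2 sketch. It is not the radix-4 code discussed in this post: the class and method names are made up for illustration, it assumes interleaved re/im floats and a power-of-two length, and it recomputes twiddles on the fly instead of using tables. The forward pass leaves the spectrum in bit-reversed order and the inverse consumes exactly that order, so no reordering pass is needed.

```java
final class OutOfOrderFft {
    /** Forward: radix-2 decimation in frequency. Natural-order in, bit-reversed out. */
    static void forward(float[] a) {
        int n = a.length / 2;                       // complex elements, assumed a power of two
        for (int size = n; size >= 2; size >>= 1) {
            int half = size >> 1;
            for (int j = 0; j < half; j++) {
                double ang = -2.0 * Math.PI * j / size;
                float wr = (float) Math.cos(ang), wi = (float) Math.sin(ang);
                for (int s = 0; s < n; s += size) {
                    int p = 2 * (s + j), q = p + size;   // q is element s + j + half
                    float dr = a[p] - a[q], di = a[p + 1] - a[q + 1];
                    a[p] += a[q];                        // u + v
                    a[p + 1] += a[q + 1];
                    a[q] = dr * wr - di * wi;            // (u - v) * w
                    a[q + 1] = dr * wi + di * wr;
                }
            }
        }
    }

    /** Inverse: radix-2 decimation in time. Bit-reversed in, natural-order out, unscaled. */
    static void inverse(float[] a) {
        int n = a.length / 2;
        for (int size = 2; size <= n; size <<= 1) {
            int half = size >> 1;
            for (int j = 0; j < half; j++) {
                double ang = 2.0 * Math.PI * j / size;   // conjugate twiddle
                float wr = (float) Math.cos(ang), wi = (float) Math.sin(ang);
                for (int s = 0; s < n; s += size) {
                    int p = 2 * (s + j), q = p + size;
                    float vr = a[q] * wr - a[q + 1] * wi; // v * w
                    float vi = a[q] * wi + a[q + 1] * wr;
                    a[q] = a[p] - vr;                    // u - v*w
                    a[q + 1] = a[p + 1] - vi;
                    a[p] += vr;                          // u + v*w
                    a[p + 1] += vi;
                }
            }
        }
    }

    public static void main(String[] args) {
        float[] x = new float[2 * 8];
        for (int i = 0; i < x.length; i++) x[i] = i;
        float[] y = x.clone();
        forward(y);
        inverse(y);
        float maxErr = 0;
        for (int i = 0; i < x.length; i++) maxErr = Math.max(maxErr, Math.abs(y[i] / 8 - x[i]));
        System.out.println("round-trip max error: " + maxErr);
    }
}
```

The inverse is unscaled, so a forward/inverse round trip multiplies the data by N. Each DIT pass undoes the matching DIF pass (up to a factor of 2), which is why the two can share loop structure.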

I normally use jtransforms (2.4) for (java) fft's so i used that as a performance basis. Whilst the out-of-order results aren't identical they are just as useful for my needs so I think it's fair to compare them. In all cases i'm talking about single-threaded performance and complex input/output (jtransforms.complexForward/complexInverse), and i've only looked at single precision so far.

• An all radix-2 implementation is about half the speed of jtransforms but takes very little code.

• A naive radix-4 factorisation (40 flops) is roughly on-par with jtransforms depending on N.

• A better radix-4 factorisation (34 flops) is 10-30% faster across many sizes of interest.

• I can beat jtransforms quite handily at 2D transforms even using the naive radix-4 kernel. e.g. ~60% for 102x1024.

This only required creating a somewhat trivial implementation of a column-specific transform. More on this in another post. It will have some pretty pictures. I'm still working on code for it too.
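As a rough illustration of what a column-specific transform can look like (a guess at the general idea, not the code behind the numbers above), the sketch below runs a radix-2 DIF pass down every column at once. It assumes a row-major image of interleaved re/im floats with a power-of-two row count; `ColumnFft` and `forwardColumns` are hypothetical names. Each butterfly pairs two whole rows and reuses one twiddle across them, so the inner loop walks memory contiguously instead of striding down a single column.

```java
final class ColumnFft {
    /**
     * In-place radix-2 DIF along the column axis of a rows x cols complex image.
     * Output rows end up in bit-reversed order; rows is assumed a power of two.
     */
    static void forwardColumns(float[] a, int rows, int cols) {
        int stride = 2 * cols;                         // floats per row
        for (int size = rows; size >= 2; size >>= 1) {
            int half = size >> 1;
            for (int j = 0; j < half; j++) {
                double ang = -2.0 * Math.PI * j / size;
                float wr = (float) Math.cos(ang), wi = (float) Math.sin(ang);
                for (int s = 0; s < rows; s += size) {
                    int p = (s + j) * stride;          // start of upper row
                    int q = p + half * stride;         // start of lower row
                    // Same twiddle for every column, contiguous access along the rows.
                    for (int x = 0; x < stride; x += 2) {
                        float dr = a[p + x] - a[q + x];
                        float di = a[p + x + 1] - a[q + x + 1];
                        a[p + x] += a[q + x];
                        a[p + x + 1] += a[q + x + 1];
                        a[q + x] = dr * wr - di * wi;
                        a[q + x + 1] = dr * wi + di * wr;
                    }
                }
            }
        }
    }

    public static void main(String[] args) {
        int rows = 4, cols = 2;
        float[] img = new float[2 * rows * cols];
        for (int i = 0; i < img.length; i += 2) img[i] = 1f;  // constant image
        forwardColumns(img, rows, cols);
        System.out.println(java.util.Arrays.toString(img));
    }
}
```

A full 2D transform would pair this with ordinary row transforms; whether this layout actually beats a copy-to-scratch approach will depend on image size and cache behaviour.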

The performance inflexion point (1D) is about 2^14 elements; beyond that jtransforms pulls (well) ahead. I suspect this is related to the cache size, and such minutiae are something i just haven't explored.

Some other java/jvm related observations:

• Single micro-optimisations can make significant (up to 20%) differences to total runtime. Such as:

• Floating point common subexpression elimination;
• Loop arithmetic;
• Memory access ordering (due to java memory model and aliasing rules);
• Using final, or copying a field value to a local variable. Why? Maybe the compiler uses these as a registerisation hint?
• And more.
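For the field-to-local item, the kind of change meant is sketched below. The class is hypothetical and both methods compute the same thing; whether the second form is any faster depends entirely on the jvm and version, so treat it as something to measure rather than a rule.

```java
final class Scaler {
    private float[] data;
    private float gain;

    Scaler(float[] data, float gain) {
        this.data = data;
        this.gain = gain;
    }

    /** Reads the fields directly inside the loop. */
    void scaleViaFields() {
        for (int i = 0; i < data.length; i++) {
            data[i] *= gain;
        }
    }

    /** Copies the fields into final locals first, then loops over the locals. */
    void scaleViaLocals() {
        final float[] d = data;
        final float g = gain;
        final int n = d.length;
        for (int i = 0; i < n; i++) {
            d[i] *= g;
        }
    }

    float[] data() {
        return data;
    }

    public static void main(String[] args) {
        Scaler s = new Scaler(new float[] {1f, 2f, 3f}, 2f);
        s.scaleViaLocals();
        System.out.println(java.util.Arrays.toString(s.data()));
    }
}
```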

Memory access ordering is particularly difficult to overcome; there is no 'restrict' keyword or even possibility with the java language. Any changes one makes will be quite platform and compiler specific (although microoptimisations in general are to start with).

Update: Nope! Seems i was just a victim of a skinner box. After a check I found there are no relevant rules here, and the changes i measured were for unrelated internal compiler reasons which are of no practical use. Should've checked before wasting my time.

Floating point arithmetic is also a bit-shit-by-design so expressions can't be freely re-arranged by the compiler for performance (well, depending on how reliable you expect the results to be; i think java aims to be strict). Java currently also has no fused-multiply-add although java-9 has added one. "complex float" would be nice. I didn't feel like digging out the disassembler to explore the generated code from hotspot, though, so i'm only guessing as to how it's compiled the code.
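Two small illustrations of the floating point points, as a sketch only. The first shows why re-association isn't a free optimisation: the same three floats summed in a different order give a different answer. The second is a complex multiply written with `Math.fma`, which needs java 9 or later; the class and method names are invented for the example, and whether the jvm turns the call into a hardware instruction depends on the platform.

```java
final class FloatOrder {
    static float sumLeft(float a, float b, float c) {
        return (a + b) + c;
    }

    static float sumRight(float a, float b, float c) {
        return a + (b + c);
    }

    /** (ar + i*ai) * (br + i*bi) using fused multiply-adds; returns {re, im}. */
    static float[] mulFma(float ar, float ai, float br, float bi) {
        float re = Math.fma(ar, br, -(ai * bi));   // ar*br - ai*bi
        float im = Math.fma(ar, bi, ai * br);      // ar*bi + ai*br
        return new float[] {re, im};
    }

    public static void main(String[] args) {
        // 1e8f + 1f is not representable in single precision, so the order matters.
        System.out.println(sumLeft(1e8f, -1e8f, 1f));
        System.out.println(sumRight(1e8f, -1e8f, 1f));
        System.out.println(java.util.Arrays.toString(mulFma(1f, 2f, 3f, 4f)));
    }
}
```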

• The jvm sometimes "over-optimises" resulting in slower code over time.

I suspect it inlines too much and causes register spills. In some cases I saw 20%+ performance hits; after upgrading to the latest jvm (from u40 to u92) they seem less severe but are still present. I used "-XX:CompileCommand=dontinline,*.*" to compare to the default. It showed up as the first-of-3 runs being somewhat faster than subsequent ones, whereas normally it should improve as hotspots are identified. I just thought it was my cpu throttling at first but it continued after i'd locked it.
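A bare-bones way to see this sort of run-to-run drift is to time the same workload several times in one jvm and compare the runs, once with default settings and once with the dontinline option above. The harness below is only a sketch: `RunTimer` is a made-up name, the workload is a placeholder, and a proper benchmark tool would control for far more than this does.

```java
final class RunTimer {
    /** Times `runs` back-to-back batches of `itersPerRun` calls; returns nanoseconds per batch. */
    static long[] timeRuns(Runnable work, int runs, int itersPerRun) {
        long[] nanos = new long[runs];
        for (int r = 0; r < runs; r++) {
            long t0 = System.nanoTime();
            for (int i = 0; i < itersPerRun; i++) {
                work.run();
            }
            nanos[r] = System.nanoTime() - t0;
        }
        return nanos;
    }

    public static void main(String[] args) {
        float[] buf = new float[1 << 12];
        // Placeholder workload; substitute the transform under test.
        Runnable work = () -> {
            for (int i = 0; i < buf.length; i++) buf[i] = buf[i] * 0.5f + 1f;
        };
        long[] t = timeRuns(work, 3, 20000);
        for (int r = 0; r < t.length; r++) {
            System.out.println("run " + (r + 1) + ": " + t[r] / 1e6 + " ms");
        }
    }
}
```

If later runs come out slower than the first under default settings but not with inlining disabled, that is at least consistent with the over-optimisation guess, though it doesn't prove the cause.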

• On a shithouse register starved architecture like x86 the jvm overheads limit the complexity of routines. e.g. a (naive) radix-8 kernel ran slower than a radix-4 one due to register spills (it wasn't very good either though).

• It's very easy to take a "simple" routine and blow it out into a big library of special cases and experiments for that elusive extra single-digit% in performance gained. Actually i've done a lot of this work before but i didn't bother referencing much of it because there is no goal; only a journey.

Of course 'performance java' isn't exactly a hot topic anymore and i should be focusing on massively-parallel devices but I was only doing it to pass the time and I can tackle OpenCL another wet weekend. If I can identify isolated cases for the optimiser degradations I came across I will see what the java forums have to say about them.

I've nearly exhausted my interest for now but i still have a couple more things to try and also to identify what if any of it I want to keep. If I get really keen I will see about some sort of reusable library for special purpose fft use - one problem with jtransforms (and indeed other fft libraries) is that you need to marshal the data into/from the correct layout which can be quite costly and negate much of the performance they purport to provide. Big if on that though.
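To show what the marshalling cost looks like in practice: the jtransforms complex routines take a single interleaved array (re at 2k, im at 2k+1), so data held as separate real and imaginary planes has to be packed before the call and unpacked after it. The helper below is a hypothetical sketch of that packing, assuming split float planes as the starting layout; it doesn't call jtransforms itself.

```java
final class ComplexLayout {
    /** Packs split planes into interleaved form: out[2k] = re[k], out[2k+1] = im[k]. */
    static float[] interleave(float[] re, float[] im) {
        float[] out = new float[2 * re.length];
        for (int k = 0; k < re.length; k++) {
            out[2 * k] = re[k];
            out[2 * k + 1] = im[k];
        }
        return out;
    }

    /** Unpacks interleaved data back into caller-supplied split planes. */
    static void deinterleave(float[] packed, float[] re, float[] im) {
        for (int k = 0; k < re.length; k++) {
            re[k] = packed[2 * k];
            im[k] = packed[2 * k + 1];
        }
    }

    public static void main(String[] args) {
        float[] packed = interleave(new float[] {1f, 2f}, new float[] {3f, 4f});
        System.out.println(java.util.Arrays.toString(packed));
    }
}
```

Both loops touch every element once each way, which for small transforms can be a noticeable fraction of the transform time itself.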