
What does __syncthreads() do here exactly? I'm new to CUDA; I get the overall idea of the FlashAttention paper but not the details.


It makes every thread in the block wait until all of them have reached this point. It's worth reading a CUDA primer for more details on blocks and warps.

The threads rely on each other to fill the SRAM (CUDA shared memory) with all the needed data, so if you didn't wait, some threads would read values that haven't been written yet.
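
To make that concrete, here's a minimal sketch (not the FlashAttention kernel itself). The kernel name `reverse_block`, the host helper `run_reverse_demo`, and the block size of 256 are all made up for illustration. Each thread writes one slot of a shared-memory tile and then reads a slot that a different thread wrote, which is only safe after the barrier:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

constexpr int N = 256;  // illustrative: a single block of 256 threads

// Each thread stages one element into shared memory, then reads an element
// that a *different* thread wrote. Without the barrier, that read could
// happen before the other thread's write.
__global__ void reverse_block(const int* in, int* out) {
    __shared__ int tile[N];
    int t = threadIdx.x;
    tile[t] = in[t];           // every thread fills its own slot
    __syncthreads();           // wait until the whole block has written
    out[t] = tile[N - 1 - t];  // now safe to read another thread's slot
}

// Hypothetical host helper: returns the number of mismatches
// (0 on success) or -1 if a CUDA call fails.
int run_reverse_demo() {
    int h_in[N], h_out[N];
    for (int i = 0; i < N; ++i) h_in[i] = i;

    int *d_in = nullptr, *d_out = nullptr;
    if (cudaMalloc(&d_in, sizeof(h_in)) != cudaSuccess) return -1;
    if (cudaMalloc(&d_out, sizeof(h_out)) != cudaSuccess) {
        cudaFree(d_in);
        return -1;
    }

    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    reverse_block<<<1, N>>>(d_in, d_out);
    // This blocking copy also waits for the kernel to finish.
    cudaError_t err =
        cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    cudaFree(d_in);
    cudaFree(d_out);
    if (err != cudaSuccess) return -1;

    int mismatches = 0;
    for (int i = 0; i < N; ++i)
        if (h_out[i] != N - 1 - i) ++mismatches;
    return mismatches;
}

int main() {
    int result = run_reverse_demo();
    std::printf("mismatches: %d\n", result);
    return result == 0 ? 0 : 1;
}
```

If you delete the __syncthreads() line, the kernel has a data race: it may still appear to work on some runs, but nothing guarantees that tile[N - 1 - t] has been written when a thread reads it.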


Any CUDA primer you recommend in particular? I had this same question.


Here's an article on syncing in CUDA via cooperative groups: https://developer.nvidia.com/blog/cooperative-groups/

There's also explicit warp synchronization, i.e. __syncwarp(). More on warp primitives here: https://developer.nvidia.com/blog/using-cuda-warp-level-prim...
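
For a rough idea of what the warp-level primitives look like, here's a small sketch of a sum across one warp using __shfl_down_sync. The kernel name `warp_sum` and the helper `run_warp_sum_demo` are hypothetical, and it assumes a single full warp of 32 threads:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Sums 32 values across one warp using register shuffles instead of
// shared memory. The full mask 0xffffffff names all 32 lanes as
// participants, so the launch must use one full warp.
__global__ void warp_sum(const int* in, int* out) {
    int v = in[threadIdx.x];
    // Tree reduction: each step adds the value held by the lane
    // `offset` positions higher, halving the distance each time.
    for (int offset = 16; offset > 0; offset >>= 1)
        v += __shfl_down_sync(0xffffffffu, v, offset);
    if (threadIdx.x == 0) *out = v;  // lane 0 ends up with the total
}

// Hypothetical host helper: returns the warp sum, or -1 on a CUDA error.
int run_warp_sum_demo() {
    int h_in[32];
    for (int i = 0; i < 32; ++i) h_in[i] = i + 1;  // 1..32

    int *d_in = nullptr, *d_out = nullptr;
    if (cudaMalloc(&d_in, sizeof(h_in)) != cudaSuccess) return -1;
    if (cudaMalloc(&d_out, sizeof(int)) != cudaSuccess) {
        cudaFree(d_in);
        return -1;
    }

    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    warp_sum<<<1, 32>>>(d_in, d_out);

    int h_out = 0;
    cudaError_t err =
        cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
    cudaFree(d_in);
    cudaFree(d_out);
    return err == cudaSuccess ? h_out : -1;
}

int main() {
    std::printf("warp sum: %d\n", run_warp_sum_demo());  // 1 + ... + 32 = 528
    return 0;
}
```

The _sync shuffles synchronize the lanes named in the mask as part of the call. __syncwarp() is the one to reach for when the lanes of a warp exchange data through shared memory instead.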


Probably https://www.youtube.com/watch?v=nOxKexn3iBo (or just skimming the attached colab).


This is terrific, thanks!



