FlexAttention: PyTorch의 유연성과 FlashAttention의 성능
-
제목: "FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention"
-
주제: FlexAttention을 소개하며, PyTorch의 유연성과 FlashAttention의 성능을 결합한 새로운 API
-
작성자: Team PyTorch (Horace He, Driss Guessous, Yanbo Liang, Joy Dong)
-
주요 내용:
- 플렉스어텐션(FlexAttention)은 다양한 어텐션 변형을 몇 줄의 PyTorch 코드로 구현할 수 있는 유연한 API 제공
- torch.compile을 통해 플렉스어텐션을 플래시어텐션(FlashAttention) 커널로 낮추어 추가 메모리 없이 고성능 구현
- 역전파를 자동으로 생성하고 주어진 어텐션 마스크의 희소성을 활용하여 성능을 향상
-
주요 기능:
- 사용자 정의 함수 score_mod을 통한 어텐션 점수 수정 가능
- 블록 마스크(BlockMask)를 사용하여 희소성을 최대한 활용
-
예시:
- 전방향 어텐션은 수행에 있어 score_mod가 필요하지 않음
- 상대적 위치 인코딩, ALiBi 바이어스, 소프트캡핑, 인과 마스크 등 다양한 어텐션 변형 예시 제공
-
성능:
- FlexAttention은 사용자 정의 트리톤 커널과 비슷한 성능을 발휘
- FlashAttention2와 비교하여 전방향 계산에서 90%, 역전파 계산에서 85% 성능 달성
-
결론:
- FlexAttention을 통해 연구자들이 새로운 어텐션 변형을 손쉽게 시도할 수 있게 되기를 기대
- PyTorch의 다양한 인프라와 연계하여 구현의 재미를 더함
-
향후 작업:
- FlexAttention은 현재 PyTorch 야간 릴리스에서 사용 가능하며, 2.5.0에서 프로토타입 기능으로 출시 예정
- 추론(inference)을 위한 FlexAttention 사용 방법 설명은 추후 제공 예정
- H100 GPU에서 FlashAttention3와 성능을 맞추기 위한 작업 진행 중
- 모든 시퀀스 길이가 128의 배수일 필요가 있는 문제 해결 예정
- GQA 지원 추가 계획
-
감사의 글:
- 트리 다오(Tri Dao)의 FlashAttention 작업
- 프란시스코 마사 및 Xformers 팀
- Jax 팀의 SplashAttention 작업
- Philippe Tillet 및 Keren Zhou의 도움
- Ali Hassani와의 논의
- 어텐션 커널에 대한 불만을 제기한 모든 사람들에게 감사의 뜻 전함
2pytorch.org링크 복사하기
AI 뉴스 요약은 뉴스의 내용을 AI가 요약(GPT-4 활용)한 것입니다. 따라서 틀린 내용을 포함할 수 있습니다. 뉴스의 자세한 내용을 확인하시려면 해당 뉴스 링크를 클릭해주세요.