검색
검색
AI news 검색
회원가입로그인

FlexAttention: PyTorch의 유연성과 FlashAttention의 성능

  • 제목: "FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention"

  • 주제: FlexAttention을 소개하며, PyTorch의 유연성과 FlashAttention의 성능을 결합한 새로운 API

  • 작성자: Team PyTorch (Horace He, Driss Guessous, Yanbo Liang, Joy Dong)

  • 주요 내용:

    • 플렉스어텐션(FlexAttention)은 다양한 어텐션 변형을 몇 줄의 PyTorch 코드로 구현할 수 있는 유연한 API 제공
    • torch.compile을 통해 플렉스어텐션을 플래시어텐션(FlashAttention) 커널로 낮추어 추가 메모리 없이 고성능 구현
    • 역전파를 자동으로 생성하고 주어진 어텐션 마스크의 희소성을 활용하여 성능을 향상
  • 주요 기능:

    • 사용자 정의 함수 score_mod을 통한 어텐션 점수 수정 가능
    • 블록 마스크(BlockMask)를 사용하여 희소성을 최대한 활용
  • 예시:

    • 전방향 어텐션은 수행에 있어 score_mod가 필요하지 않음
    • 상대적 위치 인코딩, ALiBi 바이어스, 소프트캡핑, 인과 마스크 등 다양한 어텐션 변형 예시 제공
  • 성능:

    • FlexAttention은 사용자 정의 트리톤 커널과 비슷한 성능을 발휘
    • FlashAttention2와 비교하여 전방향 계산에서 90%, 역전파 계산에서 85% 성능 달성
  • 결론:

    • FlexAttention을 통해 연구자들이 새로운 어텐션 변형을 손쉽게 시도할 수 있게 되기를 기대
    • PyTorch의 다양한 인프라와 연계하여 구현의 재미를 더함
  • 향후 작업:

    • FlexAttention은 현재 PyTorch 야간 릴리스에서 사용 가능하며, 2.5.0에서 프로토타입 기능으로 출시 예정
    • 추론(inference)을 위한 FlexAttention 사용 방법 설명은 추후 제공 예정
    • H100 GPU에서 FlashAttention3와 성능을 맞추기 위한 작업 진행 중
    • 모든 시퀀스 길이가 128의 배수일 필요가 있는 문제 해결 예정
    • GQA 지원 추가 계획
  • 감사의 글:

    • 트리 다오(Tri Dao)의 FlashAttention 작업
    • 프란시스코 마사 및 Xformers 팀
    • Jax 팀의 SplashAttention 작업
    • Philippe Tillet 및 Keren Zhou의 도움
    • Ali Hassani와의 논의
    • 어텐션 커널에 대한 불만을 제기한 모든 사람들에게 감사의 뜻 전함

2pytorch.org링크 복사하기
AI 뉴스 요약은 뉴스의 내용을 AI가 요약(GPT-4 활용)한 것입니다. 따라서 틀린 내용을 포함할 수 있습니다. 뉴스의 자세한 내용을 확인하시려면 해당 뉴스 링크를 클릭해주세요.
원본 뉴스 보기