
candle-flash-attn-v3


Commits (branch main)

- 96bd751: add license (michaelfeil, 11 days ago)
- 12189e0: add gqa param (michaelfeil, 12 days ago)
- 29b2cea: add candle tests (michaelfeil, 14 days ago)
- c2143d3: remove fp8 (working) (michaelfeil, 15 days ago)
- a53221c: add gqaswitch (untested, previous working) (michaelfeil, 15 days ago)
- 44f7d01: rm: backward (michaelfeil, 15 days ago)

README


Candle Flash Attention v3 Layer

Flash Attention v3 layer for NVIDIA Hopper GPUs (sm90a architecture), for use with the candle framework.

Work supported by Baseten (https://github.com/basetenlabs). If you already work at the intersection of CUDA, LLMs, and inference, feel free to reach out; we are hiring.

Usage

```rust
use baseten_candle_flash_attn_v3;
use anyhow::Result;
use candle::{DType, Device, IndexOp, Tensor, D};

fn flash_attn_acausal() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let q = Tensor::arange(0u32, 3 * 2 * 64, &device)?
        .to_dtype(DType::F16)?
        .reshape((1, 3, 2, 64))?; // batch, head, seqlen, hidden_dim
    let k = (&q / 400.)?;
    let v = (&q / 500.)?;
    let q = (&q / 300.)?;

    // The kernel expects (batch, seqlen, head, hidden_dim), so transpose in and out.
    let att = {
        let q = q.transpose(1, 2)?;
        let k = k.transpose(1, 2)?;
        let v = v.transpose(1, 2)?;
        baseten_candle_flash_attn_v3::flash_attn(&q, &k, &v, 0.5, false, false)?.transpose(1, 2)?
    };
    assert_eq!(att.dims(), &[1, 3, 2, 64]); // back to (batch, head, seqlen, hidden_dim)
    Ok(())
}
```
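The `add gqa param` and `add gqaswitch` commits above suggest the final boolean of `flash_attn` toggles grouped-query attention (GQA) packing, where several query heads share one key/value head. A minimal sketch under that assumption; the flag's meaning and the 4:1 head ratio here are illustrative, not confirmed by this README:

```rust
use anyhow::Result;
use candle::{DType, Device, Tensor};

fn flash_attn_gqa() -> Result<()> {
    let device = Device::new_cuda(0)?;
    // 8 query heads share 2 key/value heads (a 4:1 GQA ratio).
    // Layout is (batch, seqlen, head, hidden_dim), as the kernel expects.
    let q = Tensor::randn(0f32, 1f32, (1, 32, 8, 64), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1f32, (1, 32, 2, 64), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1f32, (1, 32, 2, 64), &device)?.to_dtype(DType::F16)?;

    let softmax_scale = 1f32 / (64f32).sqrt();
    // Assumption: the last flag enables GQA packing (cf. the `add gqa param` commit).
    let out = baseten_candle_flash_attn_v3::flash_attn(&q, &k, &v, softmax_scale, false, true)?;
    assert_eq!(out.dims(), &[1, 32, 8, 64]); // output follows the query head layout
    Ok(())
}
```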

Install instructions

```toml
[dependencies]
candle = { version = "*", package = "candle-core", default-features = false }
candle-nn = { version = "*" }
candle-transformers = { version = "*" }
baseten-candle-flash-attn-v3 = { git = "https://github.com/michaelfeil/candle-flash-attn-v3", rev = "main", optional = true }
```
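Since the dependency above is declared with `optional = true`, it is typically exposed behind a Cargo feature. A minimal sketch of such a gate; the feature name `flash-attn-v3` is an assumption, not something this repository prescribes:

```toml
[features]
# Hypothetical feature name; enable with `cargo build --features flash-attn-v3`.
flash-attn-v3 = ["dep:baseten-candle-flash-attn-v3", "candle/cuda"]
```

Call sites can then be guarded with `#[cfg(feature = "flash-attn-v3")]` so the crate still builds on machines without an sm90a toolchain.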