
accelerated-pytorch-transformers-generation


Commits

List of commits on branch main:

- cd00bdbf4f549dbc1c25d834822aafa3718cefb8: Print exception on crash (#13) (ggante, 2 years ago, verified)
- 1a43bb7a006f2d1c4e587e73f25091c6e74b398e: plot sweep (#12) (ggante, 2 years ago, verified)
- 9b80a35cd31daab5a6756183a80094e617b7f9b7: Update README.md (ffxmarty, 2 years ago, verified)
- cff4a09323048565961b26252183c947b2d8c51b: Add `--profile` flag (#9) (ggante, 2 years ago, verified)
- 58034c5e6dba370c22a062e8222ee7954d8fa275: `rotate_half` without `cat` (ggante, 2 years ago, verified)
- ea3fc7c602895b0365b131e22f90fb1d5dc2bb6a: rotate_half without cat (ggante, 2 years ago, unverified)

README


Install

pip install -e .

Running LLAMA

The commands and results below were obtained on an AMD EPYC 7R32 CPU + A10G GPU (g5.2xlarge instance).

Running the default transformers model & generation:

python run_llama.py --model huggingface/llama-7b

Adding flags will change the behavior of text generation (use --help for the available flags):

python run_llama.py --model huggingface/llama-7b --preallocate --compile no
python run_llama.py --model huggingface/llama-7b --preallocate --compile static

You can profile a short run with --profile; the TensorBoard logs are stored in ./tb_logs/:

python run_llama.py --model huggingface/llama-7b --preallocate --profile
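
Assuming the traces are written through PyTorch's TensorBoard profiler integration, they can then be inspected with TensorBoard (the torch-tb-profiler plugin may be needed to render the profiler view):

tensorboard --logdir ./tb_logs/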

Results

Running the command above with batch_size=1, prompt_length=1000, new_tokens=200, cache_length=1200, dtype=fp16:

| changes | compile | tok/s | max mem (MB) | hash | commit |
|---|---|---|---|---|---|
| None | no | 23.150 | 14776.09 | 0d6aa042 | / |
| Preallocated KV cache + SDPA + shared key/value linear | no | 27.329 | 14249.72 | 0d6aa042 | 300840e4a6531d44d7129d341b6a24cf63947807 |
| above + preallocated attention_mask | no | 27.377 | 14247.73 | 0d6aa042 | 67a933cb02def42f1fe98cc57d5077b976f1f51f |
| above + shared query/key/value linear | no | 27.444 | 14247.79 | 0d6aa042 | f2e5881e8cf6d0e89f35356ff745e8bb02cb7ebc |
| above + valid_past_index as tensor + removed controlflows | no | 27.166 | 14248.19 | 0d6aa042 | 83ca672ec3c0f2c93e70da6d79bafdeb7c2f7e90 |
| above | yes (dynamic=False) | 29.139 | 14223.17 | 0d6aa042 | 9c51dc0f10df27189141b1f98823ffba214f7e08 |
| above + avoid torch.cat in rotate_half | yes (dynamic=False) | 29.385 | 14223.17 | 0d6aa042 | cff4a09323048565961b26252183c947b2d8c51b |

The hash is used as a sanity check that the modified implementation remains on par with the reference transformers implementation.
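
As an illustration of the last table row: LLaMA's rotary embeddings rely on a rotate_half helper that is usually written with torch.cat. A minimal sketch of a cat-free variant (illustrative only, not necessarily the exact code in this repository) could look like:

import torch

def rotate_half(x):
    # Reference-style implementation: concatenate the two halves with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_half_no_cat(x):
    # Cat-free variant: write both halves into a preallocated output tensor.
    out = torch.empty_like(x)
    half = x.shape[-1] // 2
    out[..., :half] = -x[..., half:]
    out[..., half:] = x[..., :half]
    return out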

The default settings

BATCH_SIZES = [1]
PROMPT_LENGTHS = [1000]
NEW_TOKENS = [200]

can be edited to run a sweep, for example:

BATCH_SIZES = [1, 2, 4, 8]
PROMPT_LENGTHS = [500, 1000, 4000]
NEW_TOKENS = [1000]
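
For reference, a benchmark script would typically iterate over the Cartesian product of these lists, along the lines of the sketch below (the run_benchmark helper is hypothetical, not the repository's actual code):

import itertools

BATCH_SIZES = [1, 2, 4, 8]
PROMPT_LENGTHS = [500, 1000, 4000]
NEW_TOKENS = [1000]

for batch_size, prompt_length, new_tokens in itertools.product(
    BATCH_SIZES, PROMPT_LENGTHS, NEW_TOKENS
):
    # Hypothetical helper: run one generation benchmark for this configuration.
    run_benchmark(batch_size=batch_size, prompt_length=prompt_length, new_tokens=new_tokens)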

Predefined sweeps

You can sweep over predefined configurations of batch sizes (for a fixed prompt length) and prompt lengths (for a fixed batch size) with the --sweep flag, e.g.

python scripts/run_llama.py --model huggingface/llama-7b --sweep batch

If you run the sweep for several generation alternatives (original code, preallocated tensors, and preallocated + compiled), you can easily compare the results with:

python scripts/plot_results.py --sweep batch
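
If you prefer a custom view of the results, a minimal plotting sketch along these lines could be adapted; the results_batch.csv file name and its columns are assumptions for illustration, not the script's actual output format:

import pandas as pd
import matplotlib.pyplot as plt

# Assumed layout: one CSV per sweep, with one row per (variant, batch size) configuration.
df = pd.read_csv("results_batch.csv")

for variant, group in df.groupby("variant"):
    plt.plot(group["batch_size"], group["tokens_per_s"], marker="o", label=variant)

plt.xlabel("batch size")
plt.ylabel("tokens / s")
plt.legend()
plt.savefig("sweep_batch.png")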