

# hf-torch-compile-benchmark 🔥

A repository to benchmark the expected speedups from `torch.compile` and `torch.nn.functional.scaled_dot_product_attention`.

## How to use it?

The project currently depends on the `main` branches of `transformers` and `optimum`, and requires `torch>=2.0`. Install the dependencies with:

```bash
pip install -r requirements.txt
```

Use `main.py` to run the benchmarks. Run `python main.py -h` to see how to use the benchmarking script. By default, the results are written to an `output.csv` file in the current directory; if that file already exists, new results are appended to it.
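For orientation, the core measurement pattern looks roughly like the sketch below. This is a minimal sketch, not the actual `main.py`: the model name, inputs, and timing loop are illustrative, and the SDPA conversion is assumed to go through `optimum`'s `BetterTransformer`.

```python
import time

import torch
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

def bench(model, inputs, n_iter=10):
    """Average forward latency in seconds (illustrative timing loop)."""
    with torch.no_grad():
        # Warmup; also triggers compilation for the torch.compile variant.
        for _ in range(3):
            model(**inputs)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iter):
            model(**inputs)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iter

tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt").to("cuda")

# Baseline: stock Hugging Face attention ("hf_time").
hf_model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16).to("cuda")

# SDPA only: swap attention for scaled_dot_product_attention ("sdpa_no_compile_time").
sdpa_model = BetterTransformer.transform(hf_model, keep_original_model=True)

# SDPA + compile ("sdpa_compile_time").
compiled_model = torch.compile(sdpa_model, mode="reduce-overhead")

for name, m in [("hf", hf_model), ("sdpa", sdpa_model), ("sdpa+compile", compiled_model)]:
    print(name, bench(m, inputs))
```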

It is best to use a small driver script to loop over various combinations of hyperparameters (`batch_size`, `seq_len`, etc.), appending the results to the same file.
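For example, a driver along these lines would work. The flag names here are assumptions, not the script's actual CLI; check `python main.py -h` for the real ones.

```python
import itertools
import subprocess

# Sweep over hyperparameter combinations; each run appends a row to output.csv.
for batch_size, max_num_tokens in itertools.product([1, 8, 32], [256]):
    subprocess.run(
        [
            "python", "main.py",
            "--batch-size", str(batch_size),          # assumed flag name
            "--max-num-tokens", str(max_num_tokens),  # assumed flag name
        ],
        check=True,
    )
```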

## Benchmarks

### RTX4090

| pt_version | model_name | compile_mode | batch_size | max_num_tokens | run_type | precision | hf_time | sdpa_no_compile_time | sdpa_compile_time | speedup_sdpa+compile | speedup_sdpa | problems |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2.0.0+cu118 | gpt2 | reduce-overhead | 1 | 256 | forward-only | torch.float16 | 0.00426 | 0.00274 | 0.00126 | 238.10% | 55.47% | |
| 2.0.0+cu118 | gpt2 | reduce-overhead | 8 | 256 | forward-only | torch.float16 | 0.00819 | 0.00817 | 0.00615 | 33.17% | 0.24% | |
| 2.0.0+cu118 | gpt2 | reduce-overhead | 32 | 256 | forward-only | torch.float16 | 0.03371 | 0.0316 | 0.02269 | 48.57% | 6.68% | |
| 2.0.0+cu118 | gpt2 | reduce-overhead | 1 | 256 | forward-only | torch.float32 | 0.00431 | 0.00327 | 0.00286 | 50.70% | 31.80% | |
| 2.0.0+cu118 | gpt2 | reduce-overhead | 8 | 256 | forward-only | torch.float32 | 0.01882 | 0.01907 | 0.01633 | 15.25% | -1.31% | |
| 2.0.0+cu118 | gpt2 | reduce-overhead | 32 | 256 | forward-only | torch.float32 | 0.08607 | 0.08528 | 0.0616 | 39.72% | 0.93% | |
| 2.0.0+cu118 | t5-base | reduce-overhead | 1 | 256 | forward-only | torch.float16 | 0.01098 | 0.01114 | 0.00588 | 86.73% | -1.44% | |
| 2.0.0+cu118 | t5-base | reduce-overhead | 8 | 256 | forward-only | torch.float16 | 0.02174 | 0.02028 | 0.01717 | 26.62% | 7.20% | |
| 2.0.0+cu118 | t5-base | reduce-overhead | 32 | 256 | forward-only | torch.float16 | 0.0965 | 0.07942 | 0.07 | 37.86% | 21.51% | |
| 2.0.0+cu118 | t5-base | reduce-overhead | 1 | 256 | forward-only | torch.float32 | 0.01473 | 0.01053 | 0.00916 | 60.81% | 39.89% | |
| 2.0.0+cu118 | t5-base | reduce-overhead | 8 | 256 | forward-only | torch.float32 | 0.03639 | 0.03587 | 0.03302 | 10.21% | 1.45% | |
| 2.0.0+cu118 | t5-base | reduce-overhead | 32 | 256 | forward-only | torch.float32 | 0.1413 | 0.14658 | 0.13006 | 8.64% | -3.60% | |
| 2.0.0+cu118 | gpt2 | reduce-overhead | 1 | 256 | generate | torch.float16 | 0.87543 | 0.67787 | 0.6773 | 29.25% | 29.14% | probably also problems with k/v cache |
| 2.0.0+cu118 | gpt2 | reduce-overhead | 8 | 256 | generate | torch.float16 | 0.93707 | 0.78795 | 0.7868 | 19.10% | 18.93% | probably also problems with k/v cache |
| 2.0.0+cu118 | gpt2 | reduce-overhead | 32 | 256 | generate | torch.float16 | 1.22092 | 0.85482 | 0.85341 | 43.06% | 42.83% | probably also problems with k/v cache |
| 2.0.0+cu118 | gpt2 | reduce-overhead | 1 | 256 | generate | torch.float32 | 0.90596 | 0.66562 | 0.66414 | 36.41% | 36.11% | probably also problems with k/v cache |
| 2.0.0+cu118 | gpt2 | reduce-overhead | 8 | 256 | generate | torch.float32 | 0.97111 | 0.82092 | 0.82009 | 18.42% | 18.30% | probably also problems with k/v cache |
| 2.0.0+cu118 | gpt2 | reduce-overhead | 32 | 256 | generate | torch.float32 | 1.54068 | 1.36056 | 1.36055 | 13.24% | 13.24% | probably also problems with k/v cache |
| 2.0.0+cu118 | t5-base | reduce-overhead | 1 | 256 | generate | torch.float16 | 0.58538 | 0.59828 | 0.59933 | -2.33% | -2.16% | problems with k/v cache |
| 2.0.0+cu118 | t5-base | reduce-overhead | 8 | 256 | generate | torch.float16 | 0.64183 | 0.65154 | 0.65241 | -1.62% | -1.49% | problems with k/v cache |
| 2.0.0+cu118 | t5-base | reduce-overhead | 32 | 256 | generate | torch.float16 | 0.67339 | 0.67719 | 0.67812 | -0.70% | -0.56% | problems with k/v cache |
| 2.0.0+cu118 | t5-base | reduce-overhead | 1 | 256 | generate | torch.float32 | 0.49707 | 0.53413 | 0.53497 | -7.08% | -6.94% | problems with k/v cache |
| 2.0.0+cu118 | t5-base | reduce-overhead | 8 | 256 | generate | torch.float32 | 0.54263 | 0.5798 | 0.58015 | -6.47% | -6.41% | problems with k/v cache |
| 2.0.0+cu118 | t5-base | reduce-overhead | 32 | 256 | generate | torch.float32 | 0.60027 | 0.64 | 0.63894 | -6.05% | -6.21% | problems with k/v cache |
| 2.1.0.dev20230504+cu118 | gpt2 | reduce-overhead | 1 | 256 | forward-only | torch.float16 | 0.00426 | 0.003 | 0.05625 | -92.43% | 42.00% | problems with torch compile |
| 2.1.0.dev20230504+cu118 | gpt2 | reduce-overhead | 8 | 256 | forward-only | torch.float16 | 0.00838 | 0.00842 | 0.06906 | -87.87% | -0.48% | problems with torch compile |
| 2.1.0.dev20230504+cu118 | gpt2 | reduce-overhead | 32 | 256 | forward-only | torch.float16 | 0.03402 | 0.03188 | 0.09866 | -65.52% | 6.71% | problems with torch compile |
| 2.1.0.dev20230504+cu118 | gpt2 | reduce-overhead | 1 | 256 | forward-only | torch.float32 | 0.00463 | 0.00345 | 0.05444 | -91.50% | 34.20% | problems with torch compile |
| 2.1.0.dev20230504+cu118 | gpt2 | reduce-overhead | 8 | 256 | forward-only | torch.float32 | 0.01917 | 0.01945 | 0.10395 | -81.56% | -1.44% | problems with torch compile |
| 2.1.0.dev20230504+cu118 | gpt2 | reduce-overhead | 32 | 256 | forward-only | torch.float32 | 0.08936 | 0.08605 | 0.13856 | -35.51% | 3.85% | problems with torch compile |
| 2.1.0.dev20230504+cu118 | t5-base | reduce-overhead | 1 | 256 | forward-only | torch.float16 | 0.01189 | 0.0121 | 0.00609 | 95.24% | -1.74% | |
| 2.1.0.dev20230504+cu118 | t5-base | reduce-overhead | 8 | 256 | forward-only | torch.float16 | 0.02182 | 0.0203 | 0.01801 | 21.15% | 7.49% | |
| 2.1.0.dev20230504+cu118 | t5-base | reduce-overhead | 32 | 256 | forward-only | torch.float16 | 0.09746 | 0.08032 | 0.08354 | 16.66% | 21.34% | |
| 2.1.0.dev20230504+cu118 | t5-base | reduce-overhead | 1 | 256 | forward-only | torch.float32 | 0.01053 | 0.01124 | 0.00929 | 13.35% | -6.32% | |
| 2.1.0.dev20230504+cu118 | t5-base | reduce-overhead | 8 | 256 | forward-only | torch.float32 | 0.03523 | 0.03615 | 0.03444 | 2.29% | -2.54% | |
| 2.1.0.dev20230504+cu118 | t5-base | reduce-overhead | 32 | 256 | forward-only | torch.float32 | 0.13752 | 0.1446 | 0.12911 | 6.51% | -4.90% | |
| 2.1.0.dev20230504+cu118 | huggyllama/llama-7b | reduce-overhead | 1 | 256 | forward-only | torch.float16 | 0.04381 | 0.02948 | 0.23042 | -80.99% | 48.61% | problems with torch compile |
| 2.1.0.dev20230504+cu118 | EleutherAI/gpt-j-6b | reduce-overhead | 1 | 256 | forward-only | torch.float16 | 0.02678 | 0.0247 | 0.19665 | -86.38% | 8.42% | problems with torch compile |

To reproduce:

## TODOs

- properly deal with attention masks (see the sketch below)
- script to nicely render the results
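Context for the attention-mask item: `torch.nn.functional.scaled_dot_product_attention` accepts an explicit `attn_mask`, so padded batches can in principle be handled as below. This is a sketch of the PyTorch API, not the repository's current behavior; the shapes are illustrative.

```python
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 1, 8, 256, 64
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# Boolean mask, True = attend; broadcastable to (batch, heads, seq, seq).
# A padding mask would set the columns of padded positions to False.
attn_mask = torch.ones(batch, 1, seq, seq, dtype=torch.bool)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 8, 256, 64])
```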