# Attention is all you need: A Pytorch Implementation
This is a PyTorch implementation of the Transformer model in "Attention is All You Need" (Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, arXiv, 2017).
The model is a novel sequence-to-sequence framework that relies on the self-attention mechanism, instead of convolution operations or recurrent structures, and achieves state-of-the-art performance on the WMT 2014 English-to-German translation task. (2017/06/12)
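For readers new to the model: the operation the paper replaces recurrence with is scaled dot-product attention. Below is a minimal PyTorch sketch of that operation for illustration; it is not the exact module used in this repository.

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (batch, n_head, seq_len, d_k); mask broadcasts over the score matrix.
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / d_k ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attn = F.softmax(scores, dim=-1)    # attention weights over the keys
    return torch.matmul(attn, v), attn  # weighted sum of values, plus the weights
```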
The official TensorFlow implementation can be found at tensorflow/tensor2tensor.
To learn more about the self-attention mechanism, see "A Structured Self-attentive Sentence Embedding".
The project now supports training and translation with a trained model.
Note that this project is still a work in progress.
BPE-related parts are not yet fully tested.
If there is any suggestion or error, feel free to file an issue to let me know. :)
## WMT'16 Multimodal Translation: de-en
An example of training for the WMT'16 Multimodal Translation task (http://www.statmt.org/wmt16/multimodal-task.html).
### 0) Download the spacy language model.
```bash
# conda install -c conda-forge spacy
python -m spacy download en
python -m spacy download de
```
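Note: recent spaCy releases removed the `en`/`de` shortcut names, so the commands above may fail on a new environment; in that case download the full packages (for example `en_core_web_sm` and `de_core_news_sm`). A quick check that the tokenizers load, with the model names as assumptions to adjust to whatever you installed:

```python
import spacy

# Model names are assumptions; older spaCy accepts the 'en'/'de' shortcuts,
# newer releases need the full package names.
spacy_en = spacy.load('en_core_web_sm')
spacy_de = spacy.load('de_core_news_sm')

print([tok.text for tok in spacy_en.tokenizer('A quick tokenization check.')])
print([tok.text for tok in spacy_de.tokenizer('Ein kurzer Test.')])
```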
### 1) Preprocess the data with torchtext and spacy.

```bash
python preprocess.py -lang_src de -lang_trg en -share_vocab -save_data m30k_deen_shr.pkl
```
### 2) Train the model

```bash
python train.py -data_pkl m30k_deen_shr.pkl -log m30k_deen_shr -embs_share_weight -proj_share_weight -label_smoothing -output_dir output -b 256 -warmup 128000 -epoch 400
```
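The `-warmup` value feeds the learning-rate schedule from the paper: the rate rises linearly for the warmup steps and then decays with the inverse square root of the step number, scaled by `lr_mul`. A small sketch of that schedule; the function name and defaults are illustrative, not the repository's exact optimizer wrapper:

```python
def transformer_lr(step, d_model=512, warmup=128000, lr_mul=2.0):
    # lr = lr_mul * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    step = max(step, 1)
    return lr_mul * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# The rate peaks around step == warmup and decays afterwards.
for s in (100, 1000, 128000, 400000):
    print(s, transformer_lr(s))
```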
### 3) Test the model

```bash
python translate.py -data_pkl m30k_deen_shr.pkl -model trained.chkpt -output prediction.txt
```
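`translate.py` loads the trained checkpoint and writes the decoded test set to `prediction.txt`. Conceptually, translation runs the encoder once and then generates target tokens autoregressively; the sketch below shows plain greedy decoding for illustration only (the repository's own `Translator` uses beam search, and `model.encode`/`model.decode` are assumed helper hooks, not its actual API):

```python
import torch

def greedy_decode(model, src, src_mask, bos_idx, eos_idx, max_len=100):
    # Encode the source once, then append the argmax token step by step.
    enc_output = model.encode(src, src_mask)
    ys = torch.full((src.size(0), 1), bos_idx, dtype=torch.long, device=src.device)
    for _ in range(max_len):
        logits = model.decode(ys, enc_output, src_mask)        # (batch, cur_len, vocab)
        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)  # most likely next token
        ys = torch.cat([ys, next_tok], dim=1)
        if (next_tok == eos_idx).all():                        # every sentence finished
            break
    return ys
```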
## (WIP) WMT'17 Multimodal Translation: de-en w/ BPE

### 1) Download and preprocess the data with BPE:
Since the interfaces are not yet unified, you need to switch the main function call from `main_wo_bpe` to `main`.
```bash
python preprocess.py -raw_dir /tmp/raw_deen -data_dir ./bpe_deen -save_data bpe_vocab.pkl -codes codes.txt -prefix deen
```
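The `-codes codes.txt` file holds the learned BPE merge operations. If you want to learn or apply codes outside `preprocess.py`, the underlying subword-nmt package can be driven directly; a rough sketch under the assumption that the raw English training text sits in `/tmp/raw_deen` (the exact file name below is a placeholder):

```python
from subword_nmt.learn_bpe import learn_bpe
from subword_nmt.apply_bpe import BPE

# Learn merge operations from raw training text (the input path is a placeholder).
with open('/tmp/raw_deen/train.en') as infile, open('codes.txt', 'w') as outfile:
    learn_bpe(infile, outfile, num_symbols=32000)

# Apply the learned codes to a new sentence.
with open('codes.txt') as codes:
    bpe = BPE(codes)
print(bpe.process_line('a machine translation example'))
```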
### 2) Train the model

```bash
python train.py -data_pkl ./bpe_deen/bpe_vocab.pkl -train_path ./bpe_deen/deen-train -val_path ./bpe_deen/deen-val -log deen_bpe -embs_share_weight -proj_share_weight -label_smoothing -output_dir output -b 256 -warmup 128000 -epoch 400
```
### 3) Test the model (not ready)

- TODO:
  - Load vocabulary.
  - Perform decoding after the translation.
## Performance
Section titled “Performance”Training
- Parameter settings:
- batch size 256
- warmup step 4000
- epoch 200
- lr_mul 0.5
- label smoothing
- no BPE or shared vocabulary
- target embedding / pre-softmax linear layer weight sharing (see the sketch after this list).
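The last two settings correspond to the `-label_smoothing` and `-embs_share_weight`/`-proj_share_weight` flags. A minimal sketch of both ideas, with illustrative sizes and padding index rather than the repository's exact values:

```python
import torch.nn as nn
import torch.nn.functional as F

def label_smoothed_loss(logits, target, eps=0.1, pad_idx=1):
    # Spread eps of the probability mass uniformly over the non-target classes.
    n_class = logits.size(-1)
    one_hot = F.one_hot(target, n_class).float()
    smoothed = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
    loss = -(smoothed * F.log_softmax(logits, dim=-1)).sum(dim=-1)
    return loss.masked_fill(target == pad_idx, 0.0).sum()

# Weight sharing: the pre-softmax projection reuses the target embedding matrix.
d_model, vocab_size = 512, 9000          # illustrative sizes
trg_emb = nn.Embedding(vocab_size, d_model, padding_idx=1)
proj = nn.Linear(d_model, vocab_size, bias=False)
proj.weight = trg_emb.weight             # ties the two parameter tensors
```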
### Testing
- coming soon.
- Evaluation on the generated text (a BLEU sketch follows this list).
- Attention weight plot.
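Until the evaluation step is ready, BLEU on `prediction.txt` can be computed directly with sacrebleu, as sketched below; the reference file path is a placeholder for wherever the English test references live.

```python
import sacrebleu

# prediction.txt comes from translate.py; the reference path is a placeholder.
with open('prediction.txt') as f:
    hypotheses = [line.strip() for line in f]
with open('test_references.en') as f:
    references = [line.strip() for line in f]

bleu = sacrebleu.corpus_bleu(hypotheses, [references])
print(bleu.score)
```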
## Acknowledgement
- The byte pair encoding parts are borrowed from subword-nmt.
- The project structure, some scripts and the dataset preprocessing steps are heavily borrowed from OpenNMT/OpenNMT-py.
- Thanks for the suggestions from @srush, @iamalbert, @Zessay, @JulesGM, @ZiJianZhao, and @huanghoujing.