DIRECTOR: Generator-Classifiers For Supervised Language Modeling

Kushal Arora, Kurt Shuster, Sainbayar Sukhbaatar, Jason Weston

Paper Link: https://arxiv.org/abs/2206.07694

Abstract

Current language models achieve low perplexity but their resulting generations still suffer from toxic responses, repetitiveness and contradictions. The standard language modeling setup fails to address these issues. In this paper, we introduce a new architecture, DIRECTOR, that consists of a unified generator-classifier with both a language modeling and a classification head for each output token. Training is conducted jointly using both standard language modeling data, and data labeled with desirable and undesirable sequences. Experiments in several settings show that the model has competitive training and decoding speed compared to standard language models while yielding superior results, alleviating known issues while maintaining generation quality. It also outperforms existing model guiding approaches in terms of both accuracy and efficiency.
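To make the role of the two heads concrete, here is a sketch of how their outputs are combined, assuming the formulation in the linked paper. The notation is illustrative and should be checked against the paper. Here $x_t$ is the next token and $c = 1$ marks the desirable class. $\gamma_{\text{train}}$ and $\gamma_{\text{infer}}$ are weighting coefficients that presumably correspond to the `--train-gamma` and `--infer-gamma` flags used in the commands further down.

```latex
% Joint training objective: language modeling loss plus a weighted classifier loss
\mathcal{L}_{\text{train}} = \mathcal{L}_{\text{LM}} + \gamma_{\text{train}} \, \mathcal{L}_{\text{class}}

% Decoding: each candidate token's LM probability is rescaled by the
% classifier head's probability that the token keeps the sequence desirable
P_{\text{DIRECTOR}}(x_t \mid x_{<t}, c = 1) \;\propto\;
  P_{\text{LM}}(x_t \mid x_{<t}) \cdot
  P_{\text{class}}(c = 1 \mid x_{<t}, x_t)^{\gamma_{\text{infer}}}
```

Under this reading, $\gamma_{\text{infer}} = 0$ recovers the plain language model, and larger values give the classifier head more influence over the next-token distribution.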

Safety Experiments Commands:

Train the evaluation classifier:

parlai train --task projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+neg_only -et projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+neg_only -vtim 120 --model transformer/classifier \
--load-from-pretrained-ranker True --init-model zoo:pretrained_transformers/bi_model_huge_reddit/model --dict-file zoo:pretrained_transformers/bi_model_huge_reddit/model.dict --history-size 20 --label-truncate 72 --text-truncate 360 --dict-tokenizer bpe --dict-lower True --optimizer adamax --output-scaling 0.06 --variant xlm --reduction-type mean --share-encoders False --learn-positional-embeddings True --n-layers 12 --n-heads 12 --ffn-size 3072 --attention-dropout 0.1 --relu-dropout 0.0 --dropout 0.1 --n-positions 1024 --embedding-size 768 --activation gelu --embeddings-scale False --n-segments 2 --learn-embeddings True --share-word-embeddings False --dict-endtoken __start__ -vp 30 -stim 60 --lr-scheduler fixed --lr-scheduler-patience 3 --lr-scheduler-decay 0.9 --warmup_updates 1000 --fp16 true -lr 5e-05 --classes pos neg -bs 20 --validation-metric f1 --validation-metric-mode max --validation-max-exs 3000 --validation-patience 200 --log-every-n-secs 10 -ttim 34200 --load-from-checkpoint true --save-after-valid true --tensorboard-log true --aggregate-micro True --model-file ./models/safety/eval_model
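
The task list passed to --task and -et above is the same five teachers repeated with the pos_only and neg_only mutators. As a convenience, it can be assembled with a short shell loop instead of being typed by hand. This is a minimal sketch: the variable names (prefix, mutators, tasks) are illustrative and not part of ParlAI.

```shell
#!/bin/sh
# Sketch: rebuild the comma-separated task list used for --task and -et above.
# The variable names below are illustrative only.
prefix="projects.director.tasks.safety"
mutators="flatten+safety_relabel_classes"
tasks=""
for label in pos_only neg_only; do
  for teacher in SafeBADTeacher SafeAdvTeacher SafeStdTeacher SafeMultiTeacher SafeWikiToxicTeacher; do
    # ${tasks:+$tasks,} adds a separating comma only once the list is non-empty.
    tasks="${tasks:+$tasks,}${prefix}:${teacher}:mutators=${mutators}+${label}"
  done
done
printf '%s\n' "$tasks"
```

The resulting string can then be supplied as `--task "$tasks" -et "$tasks"` in place of the two literal lists.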

Train the DIRECTOR Model:

parlai train -vtim 300 -bs 6 --gradient-clip 10.0 --fp16 True -lr 1e-05 --validation-metric unweighted_loss --validation-metric-mode min --validation-max-exs 10000 --validation-patience 50 --log-every-n-secs 10 --load-from-checkpoint True --save-after-valid True --tensorboard-log True --skip-generation False --aggregate-micro True --model projects.director.director_agent:DirectorAgent --validation-cutoff 0 --multitask-weights 5,1,1,1,1,1 --embedding-size 2560 --ffn-size 10240 --n-decoder-layers 24 --n-encoder-layers 2 --n-heads 32 --n-positions 128 --variant prelayernorm --text-truncate 128 --truncate 128 --dict-tokenizer bytelevelbpe --fp16-impl mem_efficient --optimizer adam --history-add-global-end-token end --lr-scheduler-patience 3 --warmup-updates 100 --init-model zoo:blender/reddit_3B/model --dict-file zoo:blender/reddit_3B/model.dict --model-parallel True -t blended_skill_talk:mutators=flatten,projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY \
-et blended_skill_talk:mutators=flatten,projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY --train-gamma 3.0 --model-file ./models/safety/director_model
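
The training command above lists six tasks and six --multitask-weights values. Assuming the weights apply to the tasks in the order they are listed, blended_skill_talk receives weight 5 and each safety teacher receives weight 1. The sketch below prints the pairing and warns if the two lists differ in length. The variable names are illustrative, not part of ParlAI.

```shell
#!/bin/sh
# Sketch: pair each training task with its multitask weight, assuming
# the weights are positional (first weight -> first task, and so on).
tasks="blended_skill_talk:mutators=flatten"
for teacher in SafeBADTeacher SafeAdvTeacher SafeStdTeacher SafeMultiTeacher SafeWikiToxicTeacher; do
  tasks="${tasks},projects.director.tasks.safety:${teacher}:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY"
done
weights="5,1,1,1,1,1"

# Count comma-separated entries in each list.
n_tasks=$(printf '%s\n' "$tasks" | tr ',' '\n' | wc -l | tr -d ' ')
n_weights=$(printf '%s\n' "$weights" | tr ',' '\n' | wc -l | tr -d ' ')
if [ "$n_tasks" -ne "$n_weights" ]; then
  echo "task/weight count mismatch: $n_tasks tasks vs $n_weights weights" >&2
fi

# Print one "weight<TAB>task" line per entry.
i=1
for w in $(printf '%s' "$weights" | tr ',' ' '); do
  t=$(printf '%s\n' "$tasks" | cut -d ',' -f "$i")
  printf '%s\t%s\n' "$w" "$t"
  i=$((i + 1))
done
```

Running this check before a long training job is a cheap way to catch a task that was added or removed without updating the weights.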

Evaluate the DIRECTOR model on the toxic prompts from the WikiToxicComments dataset:

python -m parlai.scripts.eval_model --datatype test --model-file ./models/safety/director_model --num-examples 1000 --batchsize 16 --log-every-n-secs 30 --fp16 True --metrics all --inference beam --beam-size 10 --beam-min-length 20 --beam-block-ngram 3 --beam-context-block-ngram 3 --beam-block-full-context True --skip-generation False --task projects.director.tasks.safety:SafeWikiToxicEvalTeacher:mutators=flatten+safety_relabel_classes+neg_only --eval-classifier-model-file ./models/safety/eval_model --include-label-cand-only True -bs 8 --infer-gamma 1