Swapping Out Transformer Subcomponents
Author: Spencer Poff
Sometimes you find yourself wanting to experiment with an architecture that looks a lot like another, but with one component modified. If that component is buried deep within the model, subclassing is a poor fit: you end up copying and pasting much of the original implementation.
To make this easier and avoid copypasta, we provide the @swappable decorator.
Making a Module Swappable
Let’s say you have an existing class, TransformerLayer, that uses a module that you’d like to modify, TransformerFFN. You can make that FFN swappable in two steps:
1. Decorate TransformerLayer with @swappable, passing in a name for the component you’d like to swap and its default class/constructor:

    @swappable(ffn=TransformerFFN)
    class TransformerLayer(nn.Module):
        ...

2. At runtime, the class for ffn will be added to a property, swappables, of TransformerLayer. Replace your instantiation of TransformerFFN with a call to that constructor:

    self.feedforward = self.swappables.ffn(opt, ...)
Making the Swap
You can now replace TransformerFFN with whatever class or constructor you want before instantiating TransformerLayer:

    layer = TransformerLayer.with_components(ffn=NewCustomFFN)(opt, ...)
As long as NewCustomFFN has the same __init__ and forward method signatures as TransformerFFN, everything should just work.
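Because the replacement is constructed and called in place of the default, a mismatch in parameters only surfaces when the layer is built or run. One way to catch it earlier is a small signature comparison. The sketch below is not part of the library: same_signatures, DefaultFFN, CompatibleFFN, and IncompatibleFFN are hypothetical names, and the choice to compare __init__ and forward is an assumption about which methods matter.

```python
import inspect


def _params(func):
    """Parameter names, kinds, and defaults of a callable (annotations ignored)."""
    return [
        (p.name, p.kind, p.default)
        for p in inspect.signature(func).parameters.values()
    ]


def same_signatures(default_cls, replacement_cls, methods=("__init__", "forward")):
    """Return True if each named method takes the same parameters in both classes."""
    return all(
        _params(getattr(default_cls, name)) == _params(getattr(replacement_cls, name))
        for name in methods
    )


# Hypothetical stand-ins for a default component and two candidate replacements.
class DefaultFFN:
    def __init__(self, opt, dim, dim_hidden):
        self.dim, self.dim_hidden = dim, dim_hidden

    def forward(self, x):
        return x


class CompatibleFFN:
    def __init__(self, opt, dim, dim_hidden):
        self.dim, self.dim_hidden = dim, dim_hidden

    def forward(self, x):
        return x


class IncompatibleFFN:
    def __init__(self, opt, dim):  # missing dim_hidden
        self.dim = dim

    def forward(self, x, mask):  # extra required argument
        return x


compatible = same_signatures(DefaultFFN, CompatibleFFN)
incompatible = same_signatures(DefaultFFN, IncompatibleFFN)
```

A check like this is stricter than necessary (a replacement with extra optional arguments would also work in practice), so treat it as a quick sanity test rather than a requirement.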
Since the swapping happens before instantiation, decorated components can be transparently composed. For example:
    model = TransformerGeneratorModel.with_components(
        encoder=TransformerEncoder.with_components(
            layer=TransformerEncoderLayer.with_components(
                self_attention=MultiHeadAttention,
                feedforward=TransformerFFN,
            )
        ),
        decoder=TransformerDecoder.with_components(
            layer=TransformerDecoderLayer.with_components(
                encoder_attention=MultiHeadAttention,
                self_attention=MultiHeadAttention,
                feedforward=TransformerFFN,
            )
        ),
    )(opt=self.opt, dictionary=self.dict)
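To see why swapping before instantiation makes nesting work, it can help to look at a toy version of the pattern. The following is a minimal sketch in plain Python, not the library's actual implementation: this swappable function, and the FFN, GatedFFN, Layer, and Encoder classes, are simplified stand-ins written for illustration, and the real decorator may differ in how it stores and validates components.

```python
from types import SimpleNamespace


def swappable(**defaults):
    """Toy stand-in for a swappable-style decorator (illustrative only)."""

    def decorate(cls):
        # Record the default class/constructor for each named component.
        cls.swappables = SimpleNamespace(**defaults)

        def with_components(klass, **overrides):
            unknown = set(overrides) - set(vars(klass.swappables))
            if unknown:
                raise KeyError(f"unknown components: {sorted(unknown)}")
            merged = {**vars(klass.swappables), **overrides}
            # Build a new subclass so the original class keeps its defaults.
            return type(
                klass.__name__, (klass,), {"swappables": SimpleNamespace(**merged)}
            )

        cls.with_components = classmethod(with_components)
        return cls

    return decorate


class FFN:
    def __init__(self, opt):
        self.dim = opt["dim"]

    def forward(self, x):
        return [v * 2 for v in x]


class GatedFFN(FFN):
    def forward(self, x):
        return [v * 3 for v in x]


@swappable(ffn=FFN)
class Layer:
    def __init__(self, opt):
        # Look up the component on the class instead of hard-coding FFN.
        self.feedforward = self.swappables.ffn(opt)

    def forward(self, x):
        return self.feedforward.forward(x)


@swappable(layer=Layer)
class Encoder:
    def __init__(self, opt):
        self.layers = [self.swappables.layer(opt) for _ in range(opt["n_layers"])]

    def forward(self, x):
        for layer in self.layers:
            x = layer.forward(x)
        return x


opt = {"dim": 4, "n_layers": 2}
default_encoder = Encoder(opt)
custom_encoder = Encoder.with_components(
    layer=Layer.with_components(ffn=GatedFFN)
)(opt)
```

In this sketch with_components returns a new class rather than an instance, which is what lets the result be passed as a component to an outer with_components call. The original Layer and Encoder classes are left untouched, so default and customized models can coexist in the same process.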