Description
Fix skip connection,
x_skip = self .layer_norm (x )
attn_output , attn_output_weights = self .multihead_attn (query = x , key = x , value = x )
.
Look into class token, mean pooling for classification,
class ClassificationMLP(nn.Module):
    """Linear classification head applied to the encoded patch sequence.

    Two ways of collapsing the patch sequence are supported:

    * ``"flatten"`` (default, original behaviour): every non-batch dimension
      is flattened and fed to a linear layer with
      ``embedding_dim * patch_count`` input features.
    * ``"mean"``: the input is averaged over dimension 1 and fed to a linear
      layer with ``embedding_dim`` input features. This mode requires the
      input to be laid out as ``(batch, patch_count, embedding_dim)``.

    Args:
        num_classes: Number of output logits.
        embedding_dim: Size of each patch embedding.
        patch_count: Number of patches per sample (only affects the layer
            size in ``"flatten"`` mode).
        pooling: ``"flatten"`` or ``"mean"``; keyword-only.

    Raises:
        ValueError: If ``pooling`` is not one of the supported modes.
    """

    _POOLING_MODES = ("flatten", "mean")

    def __init__(
        self,
        num_classes: int,
        embedding_dim: int,
        patch_count: int,
        *,
        pooling: str = "flatten",
    ) -> None:
        super().__init__()
        if pooling not in self._POOLING_MODES:
            raise ValueError(
                f"pooling must be one of {self._POOLING_MODES}, got {pooling!r}"
            )
        self.pooling = pooling
        # Always created so that code accessing `.flatten` keeps working;
        # nn.Flatten has no parameters, so the state dict is unaffected.
        self.flatten = nn.Flatten()
        in_features = embedding_dim * patch_count if pooling == "flatten" else embedding_dim
        self.class_mapping_head = nn.Linear(in_features=in_features, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Collapse the patch sequence and return ``num_classes`` logits per sample."""
        if self.pooling == "mean":
            # Average over the patch axis (dim 1).
            x = x.mean(dim=1)
        else:
            x = self.flatten(x)
        return self.class_mapping_head(x)
.
Look into torch parameter for,
self .positions = torch .arange (start = 0 , end = self .num_patches , step = 1 , dtype = torch .int , device = self .device )
.
Merge as sequential,
class VisionTransformerModel(nn.Module):
    """Patch/position encoder followed by a stack of transformer encoder layers.

    The encoder layers are held in an ``nn.Sequential`` stored under the same
    attribute name (``transformer_layers``) as the previous ``nn.ModuleList``,
    so sub-module keys remain ``transformer_layers.<index>`` and the
    container can still be iterated and indexed.

    Args:
        num_heads: Passed to every ``TransformerEncoderModel``.
        embedding_dim: Passed to every ``TransformerEncoderModel`` and to the
            ``PatchPositionEncoder``.
        transformer_layers_count: Number of encoder layers to stack.
        num_patches: Passed to the ``PatchPositionEncoder``.
        device: Passed to the ``PatchPositionEncoder``.
    """

    def __init__(
        self,
        num_heads: int,
        embedding_dim: int,
        transformer_layers_count: int,
        num_patches: int,
        device: torch.device,
    ) -> None:
        super().__init__()
        self.transformer_layers_count = transformer_layers_count
        # Registered before `patch_encoder`, as in the original, so the order
        # of `parameters()` / `named_modules()` does not change.
        self.transformer_layers = nn.Sequential(
            *(
                TransformerEncoderModel(num_heads=num_heads, embedding_dim=embedding_dim)
                for _ in range(transformer_layers_count)
            ),
        )
        self.patch_encoder = PatchPositionEncoder(
            num_patches=num_patches,
            embedding_dim=embedding_dim,
            device=device,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` with the patch encoder, then apply every encoder layer in order."""
        # An empty nn.Sequential returns its input unchanged, which matches
        # the original loop when `transformer_layers_count` is 0.
        return self.transformer_layers(self.patch_encoder(x))
.
Add `view` as an alternative to `unfold` for shaping to ViT format,
# (N C IMG_H IMG_W) -> (N C PATCH_COUNT_H PATCH_COUNT_W PATCH_SIZE_H PATCH_SIZE_W)
# (64 3 72 72) -> (64 3 12 12 6 6) -> (64 12 12 6 6 3) -> (64 144 108)
patches = img .unfold (2 , self .patch_size , self .patch_size ).unfold (3 , self .patch_size , self .patch_size )
patches = patches .permute (0 , 2 , 3 , 4 , 5 , 1 )
patches = patches .reshape (self .batch_size , self .patch_count , - 1 ) # TODO: torch.view
.
Reactions are currently unavailable
You can’t perform that action at this time.
paper-implementations/pytorch/vision_transformer/vit.py
Lines 95 to 96 in 10ed370
paper-implementations/pytorch/vision_transformer/vit.py
Lines 128 to 142 in 10ed370
paper-implementations/pytorch/vision_transformer/vit.py
Line 118 in 10ed370
paper-implementations/pytorch/vision_transformer/vit.py
Lines 145 to 174 in 10ed370
`view` as alternative to `unfold` for shaping to vit format, paper-implementations/pytorch/vision_transformer/vit.py
Lines 270 to 274 in 10ed370