
Fix ViT implementation mistakes according to feedback. #6

@quickgrid

Description

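  • Fix the skip connection. The current code stores the layer-normalized tensor as x_skip and passes the un-normalized x to attention. In a pre-norm ViT block, the skip should carry the raw input and attention should see the normalized tensor:
    x_skip = self.layer_norm(x)
    attn_output, attn_output_weights = self.multihead_attn(query=x, key=x, value=x)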
  • Look into a class token or mean pooling for classification. The current head flattens every patch embedding, so its linear layer needs embedding_dim * patch_count input features:
    class ClassificationMLP(nn.Module):
        def __init__(
            self,
            num_classes: int,
            embedding_dim: int,
            patch_count: int,
        ) -> None:
            super(ClassificationMLP, self).__init__()
            self.flatten = nn.Flatten()
            self.class_mapping_head = nn.Linear(in_features=embedding_dim * patch_count, out_features=num_classes)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.flatten(x)
            x = self.class_mapping_head(x)
            return x
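  • Look into a torch parameter (or registered buffer) for the position indices, instead of a plain tensor attribute pinned to self.device:
    self.positions = torch.arange(start=0, end=self.num_patches, step=1, dtype=torch.int, device=self.device)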
  • Merge the transformer layers into an nn.Sequential instead of looping over an nn.ModuleList:
    class VisionTransformerModel(nn.Module):
        def __init__(
            self,
            num_heads: int,
            embedding_dim: int,
            transformer_layers_count: int,
            num_patches: int,
            device: torch.device,
        ) -> None:
            super(VisionTransformerModel, self).__init__()
            self.transformer_layers_count = transformer_layers_count
            self.transformer_layers = nn.ModuleList()
            for _ in range(transformer_layers_count):
                self.transformer_layers.append(
                    TransformerEncoderModel(num_heads=num_heads, embedding_dim=embedding_dim),
                )
            self.patch_encoder = PatchPositionEncoder(
                num_patches=num_patches,
                embedding_dim=embedding_dim,
                device=device,
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.patch_encoder(x)
            for transformer_layer in self.transformer_layers:
                x = transformer_layer(x)
            return x
  • Add view as an alternative to unfold for reshaping images into the ViT patch format:
    # (N C IMG_H IMG_W) -> (N C PATCH_COUNT_H PATCH_COUNT_W PATCH_SIZE_H PATCH_SIZE_W)
    # (64 3 72 72) -> (64 3 12 12 6 6) -> (64 12 12 6 6 3) -> (64 144 108)
    patches = img.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
    patches = patches.permute(0, 2, 3, 4, 5, 1)
    patches = patches.reshape(self.batch_size, self.patch_count, -1)  # TODO: torch.view
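For the skip-connection item, a minimal sketch of a pre-norm encoder block, assuming batch-first input of shape (N, patch_count, embedding_dim). The class name PreNormEncoderBlock, the second LayerNorm, and the GELU MLP sized by mlp_dim are illustrative assumptions, not the repository's TransformerEncoderModel. The point is the order: keep the raw x for the residual, normalize, then attend over the normalized tensor.

```python
import torch
from torch import nn


class PreNormEncoderBlock(nn.Module):
    """Hypothetical pre-norm ViT encoder block; the skip carries the raw input."""

    def __init__(self, num_heads: int, embedding_dim: int, mlp_dim: int = 128) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=embedding_dim, num_heads=num_heads, batch_first=True
        )
        self.norm2 = nn.LayerNorm(embedding_dim)
        # Assumed MLP shape: Linear -> GELU -> Linear.
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embedding_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention sub-block: save the un-normalized input for the residual,
        # then normalize and attend over the normalized tensor.
        x_skip = x
        x_norm = self.norm1(x)
        attn_output, _ = self.multihead_attn(query=x_norm, key=x_norm, value=x_norm)
        x = attn_output + x_skip

        # MLP sub-block with its own residual connection.
        x = self.mlp(self.norm2(x)) + x
        return x


block = PreNormEncoderBlock(num_heads=4, embedding_dim=64)
out = block(torch.randn(2, 144, 64))
print(out.shape)  # torch.Size([2, 144, 64])
```

A quick sanity check of the residual path: if the attention output projection and the last MLP layer are zeroed, the block returns its input unchanged.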
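For the classification item, a hedged sketch of the two usual alternatives to flattening: prepend a learnable class token before the transformer layers and classify from position 0, or average over all patch tokens. ClassToken and PooledClassificationHead are hypothetical names, and the demo skips the transformer layers that would normally sit between the two modules.

```python
import torch
from torch import nn


class ClassToken(nn.Module):
    """Hypothetical module: prepends a learnable class token to the patch sequence."""

    def __init__(self, embedding_dim: int) -> None:
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, P, D) -> (N, P + 1, D); the same token is shared across the batch.
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        return torch.cat([cls, x], dim=1)


class PooledClassificationHead(nn.Module):
    """Hypothetical replacement for ClassificationMLP that does not flatten patches."""

    def __init__(self, num_classes: int, embedding_dim: int, pool: str = "mean") -> None:
        super().__init__()
        if pool not in ("cls", "mean"):
            raise ValueError(f"unknown pool mode: {pool}")
        self.pool = pool
        self.norm = nn.LayerNorm(embedding_dim)
        self.class_mapping_head = nn.Linear(embedding_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, seq_len, embedding_dim) -> (N, embedding_dim) -> (N, num_classes)
        x = x[:, 0] if self.pool == "cls" else x.mean(dim=1)
        return self.class_mapping_head(self.norm(x))


tokens = ClassToken(embedding_dim=64)(torch.randn(2, 144, 64))
print(tokens.shape)  # torch.Size([2, 145, 64])
logits = PooledClassificationHead(num_classes=10, embedding_dim=64, pool="cls")(tokens)
print(logits.shape)  # torch.Size([2, 10])
```

With the demo's hypothetical sizes (embedding_dim=64, 144 patches, 10 classes), the flattening head's linear layer would hold 64 * 144 * 10 = 92,160 weights, while the pooled head's holds 64 * 10 = 640, and it no longer depends on patch_count.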
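For the position-index item, one option is register_buffer: the indices then follow module.to(device), so the device argument is no longer needed. The sketch below assumes a linear patch projection plus an nn.Embedding over positions; PatchPositionEncoderSketch and patch_dim are hypothetical. The other common choice is a learnable nn.Parameter of shape (1, num_patches, embedding_dim) added directly to the projected patches, which removes the index tensor altogether.

```python
import torch
from torch import nn


class PatchPositionEncoderSketch(nn.Module):
    """Hypothetical PatchPositionEncoder variant that needs no device argument."""

    def __init__(self, num_patches: int, patch_dim: int, embedding_dim: int) -> None:
        super().__init__()
        self.projection = nn.Linear(patch_dim, embedding_dim)
        self.position_embedding = nn.Embedding(num_patches, embedding_dim)
        # A buffer is not a trainable parameter but still moves with .to(device).
        # persistent=False keeps it out of state_dict, since it is rebuilt from num_patches.
        self.register_buffer("positions", torch.arange(num_patches), persistent=False)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        # patches: (N, num_patches, patch_dim) -> (N, num_patches, embedding_dim);
        # the (num_patches, embedding_dim) position table broadcasts over N.
        return self.projection(patches) + self.position_embedding(self.positions)


encoder = PatchPositionEncoderSketch(num_patches=144, patch_dim=108, embedding_dim=64)
print([name for name, _ in encoder.named_buffers()])  # ['positions']
print(encoder(torch.randn(2, 144, 108)).shape)  # torch.Size([2, 144, 64])
```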
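For the sequential item, a sketch that builds the layers once inside nn.Sequential so forward is a single call. To keep it self-contained, torch's nn.TransformerEncoderLayer (norm_first=True, batch_first=True; norm_first needs PyTorch 1.10 or later) stands in for the repository's TransformerEncoderModel, and an nn.Linear stands in for PatchPositionEncoder. Both substitutions, the class name, and patch_dim are assumptions.

```python
import torch
from torch import nn


class VisionTransformerSequential(nn.Module):
    """Hypothetical sequential variant of VisionTransformerModel."""

    def __init__(
        self,
        num_heads: int,
        embedding_dim: int,
        transformer_layers_count: int,
        patch_dim: int,
    ) -> None:
        super().__init__()
        self.model = nn.Sequential(
            # Stand-in for PatchPositionEncoder.
            nn.Linear(patch_dim, embedding_dim),
            # Stand-ins for TransformerEncoderModel, unpacked into the Sequential.
            *[
                nn.TransformerEncoderLayer(
                    d_model=embedding_dim,
                    nhead=num_heads,
                    dim_feedforward=128,
                    batch_first=True,
                    norm_first=True,
                )
                for _ in range(transformer_layers_count)
            ],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One call replaces the manual loop over an nn.ModuleList.
        return self.model(x)


model = VisionTransformerSequential(
    num_heads=4, embedding_dim=64, transformer_layers_count=3, patch_dim=108
)
print(model(torch.randn(2, 144, 108)).shape)  # torch.Size([2, 144, 64])
```

Because every submodule's parameters and buffers live inside the Sequential, model.to(device) moves them together.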
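For the view item, a sketch comparing the unfold path with a view-based one. view works on the contiguous input image, so the first split can be a view; after permute the tensor is no longer contiguous, so the final flatten still needs reshape (or .contiguous().view). That is why the TODO cannot simply swap reshape for view. The function names are hypothetical, the sketch reads the batch size from the input instead of self.batch_size, and it assumes H and W are divisible by patch_size (unlike unfold, which drops the remainder, view raises an error otherwise).

```python
import torch


def patchify_unfold(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Mirrors the unfold-based code: (N, C, H, W) -> (N, patch_count, p * p * C).
    n, c = img.shape[:2]
    patches = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    patches = patches.permute(0, 2, 3, 4, 5, 1)
    return patches.reshape(n, -1, patch_size * patch_size * c)


def patchify_view(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    n, c, h, w = img.shape
    p = patch_size
    # (N, C, H, W) -> (N, C, H/p, p, W/p, p); valid because img is contiguous.
    patches = img.view(n, c, h // p, p, w // p, p)
    # -> (N, H/p, W/p, p, p, C), the same axis order as the unfold version.
    patches = patches.permute(0, 2, 4, 3, 5, 1)
    # The permuted tensor is non-contiguous, so reshape (not view) flattens it.
    return patches.reshape(n, (h // p) * (w // p), p * p * c)


img = torch.randn(64, 3, 72, 72)
a = patchify_unfold(img, 6)
b = patchify_view(img, 6)
print(a.shape, torch.equal(a, b))  # torch.Size([64, 144, 108]) True
```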
