Skip to content

Pytorch后端NHWC和NCHW问题 #69

@Windaway

Description

@Windaway

New Issue Checklist

Issue Description

Pytorch后端模型定义NHWC和NCHW数据格式主要是定以数据和模型后传到设备时用.to("cuda:0", memory_format=torch.channels_last)确定。

TLX目前做法是pytorch依据nhwc格式时,全部转NCHW然后处理完转回来,这潜在是让模型用NCHW格式计算。对纯GPU应用时问题不大,但是对于一些NHWC友好的设备部署,比如未来的Mindspore,由于多次nhwc nchw切换,性能有损失。

这里可能需要框架对于pytorch这里nhwc支持改成全局变量,即输入时数据做nchw-nhwc,模型转nhwc然后计算即可。

不过Pytorch本身GPU NHWC支持稀烂,倒不是很急。

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions