Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix StreamWriter regression around RGB0/BGR0 #3428

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions torchaudio/csrc/ffmpeg/stream_writer/encode_process.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,33 +96,30 @@ enum AVSampleFormat get_src_sample_fmt(const std::string& src) {
".");
}

const std::set<AVPixelFormat> SUPPORTED_PIX_FMTS{
AV_PIX_FMT_GRAY8,
AV_PIX_FMT_RGB0,
AV_PIX_FMT_BGR0,
AV_PIX_FMT_RGB24,
AV_PIX_FMT_BGR24,
AV_PIX_FMT_YUV444P};

enum AVPixelFormat get_src_pix_fmt(const std::string& src) {
AVPixelFormat fmt = av_get_pix_fmt(src.c_str());
switch (fmt) {
case AV_PIX_FMT_GRAY8:
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24:
case AV_PIX_FMT_YUV444P:
return fmt;
default:;
}
TORCH_CHECK(
false,
SUPPORTED_PIX_FMTS.count(fmt),
"Unsupported pixel format (",
src,
") was provided. Valid values are ",
[]() -> std::string {
std::vector<std::string> ret;
for (const auto& fmt :
{AV_PIX_FMT_GRAY8,
AV_PIX_FMT_RGB24,
AV_PIX_FMT_BGR24,
AV_PIX_FMT_YUV444P}) {
for (const auto& fmt : SUPPORTED_PIX_FMTS) {
ret.emplace_back(av_get_pix_fmt_name(fmt));
}
return c10::Join(", ", ret);
}(),
".");
return fmt;
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
54 changes: 50 additions & 4 deletions torchaudio/csrc/ffmpeg/stream_writer/tensor_converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ namespace torchaudio::io {

namespace {

using namespace torch::indexing;

using InitFunc = TensorConverter::InitFunc;
using ConvertFunc = TensorConverter::ConvertFunc;

Expand Down Expand Up @@ -111,6 +113,28 @@ void validate_video_input(
t.sizes());
}

// Special case where encode pixel format is RGB0/BGR0 but the tensor is RGB/BGR
void validate_rgb0(const torch::Tensor& t, AVFrame* buffer) {
if (buffer->hw_frames_ctx) {
TORCH_CHECK(t.device().is_cuda(), "Input tensor has to be on CUDA.");
} else {
TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
}
TORCH_CHECK(
t.dtype().toScalarType() == c10::ScalarType::Byte,
"Expected Tensor of uint8 type.");

TORCH_CHECK(t.dim() == 4, "Input Tensor has to be 4D.");
TORCH_CHECK(
t.size(2) == buffer->height && t.size(3) == buffer->width,
"Expected tensor with shape (N, 3, ",
buffer->height,
", ",
buffer->width,
") (NCHW format). Found ",
t.sizes());
}

// NCHW ->NHWC, ensure contiguous
torch::Tensor init_interlaced(const torch::Tensor& tensor) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.dim() == 4);
Expand Down Expand Up @@ -276,16 +300,20 @@ std::pair<InitFunc, ConvertFunc> get_video_func(AVFrame* buffer) {
auto frames_ctx = (AVHWFramesContext*)(buffer->hw_frames_ctx->data);
auto sw_pix_fmt = frames_ctx->sw_format;
switch (sw_pix_fmt) {
// Note:
// RGB0 / BGR0 expects 4 channel, but neither
// av_pix_fmt_desc_get(pix_fmt)->nb_components
// or av_pix_fmt_count_planes(pix_fmt) returns 4.
case AV_PIX_FMT_RGB0:
case AV_PIX_FMT_BGR0: {
ConvertFunc convert_func = [](const torch::Tensor& t, AVFrame* f) {
write_interlaced_video_cuda(t, f, 4);
};
InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
// Special treatment for the case user pass regular RGB/BGR tensor.
if (t.dim() == 4 && t.size(1) == 3) {
validate_rgb0(t, f);
auto tmp =
torch::empty({t.size(0), t.size(2), t.size(3), 4}, t.options());
tmp.index_put_({"...", Slice(0, 3)}, t.permute({0, 2, 3, 1}));
return tmp;
}
validate_video_input(t, f, 4);
return init_interlaced(t);
};
Expand Down Expand Up @@ -327,6 +355,24 @@ std::pair<InitFunc, ConvertFunc> get_video_func(AVFrame* buffer) {
};
return {init_func, convert_func};
}
case AV_PIX_FMT_RGB0:
case AV_PIX_FMT_BGR0: {
InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
if (t.dim() == 4 && t.size(1) == 3) {
validate_rgb0(t, f);
auto tmp =
torch::empty({t.size(0), t.size(2), t.size(3), 4}, t.options());
tmp.index_put_({"...", Slice(0, 3)}, t.permute({0, 2, 3, 1}));
return tmp;
}
validate_video_input(t, f, 4);
return init_interlaced(t);
};
ConvertFunc convert_func = [](const torch::Tensor& t, AVFrame* f) {
write_interlaced_video(t, f, 4);
};
return {init_func, convert_func};
}
case AV_PIX_FMT_YUV444P: {
InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
validate_video_input(t, f, 3);
Expand Down
Loading