From fba18b815b2e59b87e1549f9fe1203743e41b983 Mon Sep 17 00:00:00 2001
From: moto <855818+mthrok@users.noreply.github.com>
Date: Thu, 8 Jun 2023 14:40:55 -0400
Subject: [PATCH] Fix StreamWriter regression around RGB0/BGR0

- Add RGB0/BGR0 support to the CPU encoder
- Allow passing RGB/BGR when the expected format is RGB0/BGR0
---
 .../ffmpeg/stream_writer/encode_process.cpp   | 25 ++++-----
 .../ffmpeg/stream_writer/tensor_converter.cpp | 54 +++++++++++++++++--
 2 files changed, 61 insertions(+), 18 deletions(-)

diff --git a/torchaudio/csrc/ffmpeg/stream_writer/encode_process.cpp b/torchaudio/csrc/ffmpeg/stream_writer/encode_process.cpp
index 3f9a153004e..88484f950b2 100644
--- a/torchaudio/csrc/ffmpeg/stream_writer/encode_process.cpp
+++ b/torchaudio/csrc/ffmpeg/stream_writer/encode_process.cpp
@@ -101,33 +101,30 @@ enum AVSampleFormat get_src_sample_fmt(const std::string& src) {
       ".");
 }
 
+const std::set<AVPixelFormat> SUPPORTED_PIX_FMTS{
+    AV_PIX_FMT_GRAY8,
+    AV_PIX_FMT_RGB0,
+    AV_PIX_FMT_BGR0,
+    AV_PIX_FMT_RGB24,
+    AV_PIX_FMT_BGR24,
+    AV_PIX_FMT_YUV444P};
+
 enum AVPixelFormat get_src_pix_fmt(const std::string& src) {
   AVPixelFormat fmt = FFMPEG av_get_pix_fmt(src.c_str());
-  switch (fmt) {
-    case AV_PIX_FMT_GRAY8:
-    case AV_PIX_FMT_RGB24:
-    case AV_PIX_FMT_BGR24:
-    case AV_PIX_FMT_YUV444P:
-      return fmt;
-    default:;
-  }
   TORCH_CHECK(
-      false,
+      SUPPORTED_PIX_FMTS.count(fmt),
       "Unsupported pixel format (",
       src,
       ") was provided. Valid values are ",
       []() -> std::string {
         std::vector<std::string> ret;
-        for (const auto& fmt :
-             {AV_PIX_FMT_GRAY8,
-              AV_PIX_FMT_RGB24,
-              AV_PIX_FMT_BGR24,
-              AV_PIX_FMT_YUV444P}) {
+        for (const auto& fmt : SUPPORTED_PIX_FMTS) {
           ret.emplace_back(FFMPEG av_get_pix_fmt_name(fmt));
         }
         return c10::Join(", ", ret);
       }(),
       ".");
+  return fmt;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/torchaudio/csrc/ffmpeg/stream_writer/tensor_converter.cpp b/torchaudio/csrc/ffmpeg/stream_writer/tensor_converter.cpp
index 1478d38d5ac..8d3bacc2cb2 100644
--- a/torchaudio/csrc/ffmpeg/stream_writer/tensor_converter.cpp
+++ b/torchaudio/csrc/ffmpeg/stream_writer/tensor_converter.cpp
@@ -8,6 +8,8 @@ namespace torchaudio::io {
 
 namespace {
 
+using namespace torch::indexing;
+
 using InitFunc = TensorConverter::InitFunc;
 using ConvertFunc = TensorConverter::ConvertFunc;
 
@@ -111,6 +113,28 @@ void validate_video_input(
       t.sizes());
 }
 
+// Special case where the encoder pixel format is RGB0/BGR0 but the tensor is RGB/BGR
+void validate_rgb0(const torch::Tensor& t, AVFrame* buffer) {
+  if (buffer->hw_frames_ctx) {
+    TORCH_CHECK(t.device().is_cuda(), "Input tensor has to be on CUDA.");
+  } else {
+    TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
+  }
+  TORCH_CHECK(
+      t.dtype().toScalarType() == c10::ScalarType::Byte,
+      "Expected Tensor of uint8 type.");
+
+  TORCH_CHECK(t.dim() == 4, "Input Tensor has to be 4D.");
+  TORCH_CHECK(
+      t.size(2) == buffer->height && t.size(3) == buffer->width,
+      "Expected tensor with shape (N, 3, ",
+      buffer->height,
+      ", ",
+      buffer->width,
+      ") (NCHW format). Found ",
+      t.sizes());
+}
+
 // NCHW -> NHWC, ensure contiguous
 torch::Tensor init_interlaced(const torch::Tensor& tensor) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.dim() == 4);
@@ -276,16 +300,20 @@ std::pair<InitFunc, ConvertFunc> get_video_func(AVFrame* buffer) {
     auto frames_ctx = (AVHWFramesContext*)(buffer->hw_frames_ctx->data);
     auto sw_pix_fmt = frames_ctx->sw_format;
     switch (sw_pix_fmt) {
-      // Note:
-      // RGB0 / BGR0 expects 4 channel, but neither
-      // av_pix_fmt_desc_get(pix_fmt)->nb_components
-      // or av_pix_fmt_count_planes(pix_fmt) returns 4.
      case AV_PIX_FMT_RGB0:
      case AV_PIX_FMT_BGR0: {
        ConvertFunc convert_func = [](const torch::Tensor& t, AVFrame* f) {
          write_interlaced_video_cuda(t, f, 4);
        };
        InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
+          // Special treatment for the case where the user passes a regular RGB/BGR tensor.
+          if (t.dim() == 4 && t.size(1) == 3) {
+            validate_rgb0(t, f);
+            auto tmp =
+                torch::empty({t.size(0), t.size(2), t.size(3), 4}, t.options());
+            tmp.index_put_({"...", Slice(0, 3)}, t.permute({0, 2, 3, 1}));
+            return tmp;
+          }
          validate_video_input(t, f, 4);
          return init_interlaced(t);
        };
@@ -327,6 +355,24 @@ std::pair<InitFunc, ConvertFunc> get_video_func(AVFrame* buffer) {
      };
      return {init_func, convert_func};
    }
+    case AV_PIX_FMT_RGB0:
+    case AV_PIX_FMT_BGR0: {
+      InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
+        if (t.dim() == 4 && t.size(1) == 3) {
+          validate_rgb0(t, f);
+          auto tmp =
+              torch::empty({t.size(0), t.size(2), t.size(3), 4}, t.options());
+          tmp.index_put_({"...", Slice(0, 3)}, t.permute({0, 2, 3, 1}));
+          return tmp;
+        }
+        validate_video_input(t, f, 4);
+        return init_interlaced(t);
+      };
+      ConvertFunc convert_func = [](const torch::Tensor& t, AVFrame* f) {
+        write_interlaced_video(t, f, 4);
+      };
+      return {init_func, convert_func};
+    }
    case AV_PIX_FMT_YUV444P: {
      InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
        validate_video_input(t, f, 3);