Fix StreamWriter regression around RGB0/BGR0
- Add RGB0/BGR0 support to CPU encoder
- Allow passing RGB/BGR when the expected format is RGB0/BGR0
mthrok committed Jul 5, 2023
1 parent 163157d commit fba18b8
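For context before the diff: a minimal usage sketch of the behavior this commit restores. This is not part of the commit; it assumes torchaudio's Python StreamWriter API, and the output path, frame rate, and frame size are made-up example values.

import torch
from torchaudio.io import StreamWriter

# 30 RGB frames in NCHW uint8 layout: 3 channels, not the 4 that RGB0 carries.
frames = torch.randint(0, 256, (30, 3, 96, 128), dtype=torch.uint8)

writer = StreamWriter("example_output.mp4")
# Declare the source format as "rgb0". With this fix the CPU encoder accepts
# RGB0/BGR0 again, and a plain RGB/BGR tensor is padded to the 4-channel layout.
writer.add_video_stream(frame_rate=30, height=96, width=128, format="rgb0")
with writer.open():
    writer.write_video_chunk(0, frames)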
Showing 2 changed files with 61 additions and 18 deletions.
25 changes: 11 additions & 14 deletions torchaudio/csrc/ffmpeg/stream_writer/encode_process.cpp
@@ -101,33 +101,30 @@ enum AVSampleFormat get_src_sample_fmt(const std::string& src) {
".");
}

+const std::set<AVPixelFormat> SUPPORTED_PIX_FMTS{
+    AV_PIX_FMT_GRAY8,
+    AV_PIX_FMT_RGB0,
+    AV_PIX_FMT_BGR0,
+    AV_PIX_FMT_RGB24,
+    AV_PIX_FMT_BGR24,
+    AV_PIX_FMT_YUV444P};

enum AVPixelFormat get_src_pix_fmt(const std::string& src) {
  AVPixelFormat fmt = FFMPEG av_get_pix_fmt(src.c_str());
-  switch (fmt) {
-    case AV_PIX_FMT_GRAY8:
-    case AV_PIX_FMT_RGB24:
-    case AV_PIX_FMT_BGR24:
-    case AV_PIX_FMT_YUV444P:
-      return fmt;
-    default:;
-  }
  TORCH_CHECK(
-      false,
+      SUPPORTED_PIX_FMTS.count(fmt),
      "Unsupported pixel format (",
      src,
      ") was provided. Valid values are ",
      []() -> std::string {
        std::vector<std::string> ret;
-        for (const auto& fmt :
-             {AV_PIX_FMT_GRAY8,
-              AV_PIX_FMT_RGB24,
-              AV_PIX_FMT_BGR24,
-              AV_PIX_FMT_YUV444P}) {
+        for (const auto& fmt : SUPPORTED_PIX_FMTS) {
          ret.emplace_back(FFMPEG av_get_pix_fmt_name(fmt));
        }
        return c10::Join(", ", ret);
      }(),
      ".");
+  return fmt;
}

////////////////////////////////////////////////////////////////////////////////
54 changes: 50 additions & 4 deletions torchaudio/csrc/ffmpeg/stream_writer/tensor_converter.cpp
@@ -8,6 +8,8 @@
namespace torchaudio::io {
namespace {

+using namespace torch::indexing;

using InitFunc = TensorConverter::InitFunc;
using ConvertFunc = TensorConverter::ConvertFunc;

@@ -111,6 +113,28 @@ void validate_video_input(
t.sizes());
}

+// Special case where encode pixel format is RGB0/BGR0 but the tensor is RGB/BGR
+void validate_rgb0(const torch::Tensor& t, AVFrame* buffer) {
+  if (buffer->hw_frames_ctx) {
+    TORCH_CHECK(t.device().is_cuda(), "Input tensor has to be on CUDA.");
+  } else {
+    TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
+  }
+  TORCH_CHECK(
+      t.dtype().toScalarType() == c10::ScalarType::Byte,
+      "Expected Tensor of uint8 type.");
+
+  TORCH_CHECK(t.dim() == 4, "Input Tensor has to be 4D.");
+  TORCH_CHECK(
+      t.size(2) == buffer->height && t.size(3) == buffer->width,
+      "Expected tensor with shape (N, 3, ",
+      buffer->height,
+      ", ",
+      buffer->width,
+      ") (NCHW format). Found ",
+      t.sizes());
+}

// NCHW ->NHWC, ensure contiguous
torch::Tensor init_interlaced(const torch::Tensor& tensor) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.dim() == 4);
@@ -276,16 +300,20 @@ std::pair<InitFunc, ConvertFunc> get_video_func(AVFrame* buffer) {
auto frames_ctx = (AVHWFramesContext*)(buffer->hw_frames_ctx->data);
auto sw_pix_fmt = frames_ctx->sw_format;
switch (sw_pix_fmt) {
-// Note:
-// RGB0 / BGR0 expects 4 channel, but neither
-// av_pix_fmt_desc_get(pix_fmt)->nb_components
-// or av_pix_fmt_count_planes(pix_fmt) returns 4.
case AV_PIX_FMT_RGB0:
case AV_PIX_FMT_BGR0: {
ConvertFunc convert_func = [](const torch::Tensor& t, AVFrame* f) {
write_interlaced_video_cuda(t, f, 4);
};
InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
+// Special treatment for the case user pass regular RGB/BGR tensor.
+if (t.dim() == 4 && t.size(1) == 3) {
+  validate_rgb0(t, f);
+  auto tmp =
+      torch::empty({t.size(0), t.size(2), t.size(3), 4}, t.options());
+  tmp.index_put_({"...", Slice(0, 3)}, t.permute({0, 2, 3, 1}));
+  return tmp;
+}
validate_video_input(t, f, 4);
return init_interlaced(t);
};
@@ -327,6 +355,24 @@ std::pair<InitFunc, ConvertFunc> get_video_func(AVFrame* buffer) {
};
return {init_func, convert_func};
}
+case AV_PIX_FMT_RGB0:
+case AV_PIX_FMT_BGR0: {
+  InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
+    if (t.dim() == 4 && t.size(1) == 3) {
+      validate_rgb0(t, f);
+      auto tmp =
+          torch::empty({t.size(0), t.size(2), t.size(3), 4}, t.options());
+      tmp.index_put_({"...", Slice(0, 3)}, t.permute({0, 2, 3, 1}));
+      return tmp;
+    }
+    validate_video_input(t, f, 4);
+    return init_interlaced(t);
+  };
+  ConvertFunc convert_func = [](const torch::Tensor& t, AVFrame* f) {
+    write_interlaced_video(t, f, 4);
+  };
+  return {init_func, convert_func};
+}
case AV_PIX_FMT_YUV444P: {
InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
validate_video_input(t, f, 3);
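For reference, the channel padding performed by the new init functions above (3-channel NCHW in, 4-channel NHWC out) can be pictured with the following torch snippet. It is an illustrative sketch that mirrors the C++ logic, not code from the commit, and the tensor sizes are arbitrary.

import torch

# Arbitrary example batch: 2 RGB frames of 96x128, uint8, NCHW layout.
rgb = torch.randint(0, 256, (2, 3, 96, 128), dtype=torch.uint8)

# Allocate an NHWC buffer with 4 channels and copy RGB into the first three,
# mirroring the torch::empty + index_put_ calls in the init functions above.
# The 4th byte (the "0" of RGB0) is left uninitialized, as in the C++ code.
rgb0 = torch.empty((rgb.size(0), rgb.size(2), rgb.size(3), 4), dtype=torch.uint8)
rgb0[..., :3] = rgb.permute(0, 2, 3, 1)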
