Skip to content

Commit

Permalink
Update forced_align method to only support batch Tensors (#3433)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3433

Current design of forced_align accept 2D Tensor for `log_probs` and 1D Tensor for `targets`. To make the API simple, the PR make changes to only support batch Tensors (3D Tensor for `log_probs` and 2D Tensor for `targets`).

Differential Revision: D46657526

fbshipit-source-id: 3d3a762d91259cb87f846bc4981ea210c2bf2b38
  • Loading branch information
nateanl authored and facebook-github-bot committed Jun 12, 2023
1 parent c587715 commit 572e610
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 124 deletions.
34 changes: 17 additions & 17 deletions examples/tutorials/ctc_forced_alignment_api_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
emissions, _ = model(waveform.to(device))
emissions = torch.log_softmax(emissions, dim=-1)

emission = emissions[0].cpu().detach()
emission = emissions.cpu().detach()
dictionary = {c: i for i, c in enumerate(labels)}

print(dictionary)
Expand All @@ -108,7 +108,7 @@
# ^^^^^^^^^^^^^
#

plt.imshow(emission.T)
plt.imshow(emission[0].T)
plt.colorbar()
plt.title("Frame-wise class probabilities")
plt.xlabel("Time")
Expand Down Expand Up @@ -206,27 +206,27 @@ def compute_alignments(transcript, dictionary, emission):
frames = []
tokens = [dictionary[c] for c in transcript.replace(" ", "")]

targets = torch.tensor(tokens, dtype=torch.int32)
input_lengths = torch.tensor(emission.shape[0])
target_lengths = torch.tensor(targets.shape[0])
targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0)
input_lengths = torch.tensor(emission.shape[0:2])
target_lengths = torch.tensor(targets.shape)

# This is the key step, where we call the forced alignment API functional.forced_align to compute alignments.
frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0)

assert len(frame_alignment) == input_lengths.item()
assert len(targets) == target_lengths.item()
assert len(frame_alignment) == input_lengths[0].item()
assert len(targets) == target_lengths[0].item()

token_index = -1
prev_hyp = 0
for i in range(len(frame_alignment)):
if frame_alignment[i].item() == 0:
for i in range(frame_alignment.shape[1]):
if frame_alignment[0][i].item() == 0:
prev_hyp = 0
continue

if frame_alignment[i].item() != prev_hyp:
if frame_alignment[0][i].item() != prev_hyp:
token_index += 1
frames.append(Frame(token_index, i, frame_scores[i].exp().item()))
prev_hyp = frame_alignment[i].item()
frames.append(Frame(token_index, i, frame_scores[0][i].exp().item()))
prev_hyp = frame_alignment[0][i].item()
return frames, frame_alignment, frame_scores


Expand Down Expand Up @@ -427,7 +427,7 @@ def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
# `IPython.display.Audio` has to be the last call in a cell,
# and there should be only one call par cell.
def display_segment(i, waveform, word_segments, frame_alignment):
ratio = waveform.size(1) / len(frame_alignment)
ratio = waveform.size(1) / frame_alignment.size(1)
word = word_segments[i]
x0 = int(ratio * word.start)
x1 = int(ratio * word.end)
Expand Down Expand Up @@ -549,12 +549,12 @@ def get_emission(waveform):

emissions, _ = model(waveform)
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()
emission = emissions.cpu().detach()

# Append the extra dimension corresponding to the <star> token
extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1)
emissions = torch.cat((emissions, extra_dim), 2)
emission = emissions[0].cpu().detach()
emission = emissions.cpu().detach()
return emission, waveform


Expand Down Expand Up @@ -597,14 +597,14 @@ def get_emission(waveform):
"x": 30,
"*": 31,
}
assert len(dictionary) == emission.shape[1]
assert len(dictionary) == emission.shape[2]


def compute_and_plot_alignments(transcript, dictionary, emission, waveform):
frames, frame_alignment, _ = compute_alignments(transcript, dictionary, emission)
segments = merge_repeats(frames, transcript)
word_segments = merge_words(transcript, segments)
plot_alignments(segments, word_segments, waveform[0], emission.shape[0])
plot_alignments(segments, word_segments, waveform[0], emission.shape[1])
plt.show()
return word_segments, frame_alignment

Expand Down
85 changes: 46 additions & 39 deletions test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,55 +1116,60 @@ def test_preemphasis_deemphasis_roundtrip(self, input_shape, coeff):

@parameterized.expand(
[
([0, 1, 1, 0], [0, 1, 5, 1, 0], torch.int32),
([0, 1, 2, 3, 4], [0, 1, 2, 3, 4], torch.int32),
([3, 3, 3], [3, 5, 3, 5, 3], torch.int64),
([0, 1, 2], [0, 1, 1, 1, 2], torch.int64),
([[0, 1, 1, 0]], [[0, 1, 5, 1, 0]], torch.int32),
([[0, 1, 2, 3, 4]], [[0, 1, 2, 3, 4]], torch.int32),
([[3, 3, 3]], [[3, 5, 3, 5, 3]], torch.int64),
([[0, 1, 2]], [[0, 1, 1, 1, 2]], torch.int64),
]
)
def test_forced_align(self, targets, ref_path, targets_dtype):
emission = torch.tensor(
[
[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
[0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
[0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
[0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
[0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107],
[
[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
[0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
[0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
[0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
[0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107],
]
],
dtype=self.dtype,
device=self.device,
)
blank = 5
batch_index = 0
ref_path = torch.tensor(ref_path, dtype=targets_dtype, device=self.device)
ref_scores = torch.tensor(
[torch.log(emission[i, ref_path[i]]).item() for i in range(emission.shape[0])],
[torch.log(emission[batch_index, i, ref_path[batch_index, i]]).item() for i in range(emission.shape[1])],
dtype=emission.dtype,
device=self.device,
)
).unsqueeze(0)
log_probs = torch.log(emission)
targets = torch.tensor(targets, dtype=targets_dtype, device=self.device)
input_lengths = torch.tensor((log_probs.shape[0]))
target_lengths = torch.tensor((targets.shape[0]))
input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor([targets.shape[1]], device=self.device)
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
assert hyp_path.shape == ref_path.shape
assert hyp_scores.shape == ref_scores.shape
self.assertEqual(hyp_path, ref_path)
self.assertEqual(hyp_scores, ref_scores)

@parameterized.expand([(torch.int32,), (torch.int64,)])
def test_forced_align_fail(self, targets_dtype):
log_probs = torch.rand(5, 6, dtype=self.dtype, device=self.device)
targets = torch.tensor([0, 1, 2, 3, 4, 4], dtype=targets_dtype, device=self.device)
log_probs = torch.rand(1, 5, 6, dtype=self.dtype, device=self.device)
targets = torch.tensor([[0, 1, 2, 3, 4, 4]], dtype=targets_dtype, device=self.device)
blank = 5
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device)
input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor([targets.shape[1]], device=self.device)
with self.assertRaisesRegex(RuntimeError, r"targets length is too long for CTC"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)

targets = torch.tensor([5, 3, 3], dtype=targets_dtype, device=self.device)
targets = torch.tensor([[5, 3, 3]], dtype=targets_dtype, device=self.device)
with self.assertRaisesRegex(ValueError, r"targets Tensor shouldn't contain blank index"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)

log_probs = log_probs.int()
targets = torch.tensor([0, 1, 2, 3], dtype=targets_dtype, device=self.device)
targets = torch.tensor([[0, 1, 2, 3]], dtype=targets_dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, r"log_probs must be float64, float32 or float16"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)

Expand All @@ -1175,40 +1180,42 @@ def test_forced_align_fail(self, targets_dtype):

log_probs = torch.rand(3, 4, 6, dtype=self.dtype, device=self.device)
targets = targets.int()
with self.assertRaisesRegex(RuntimeError, r"3-D tensor is not yet supported for log_probs"):
with self.assertRaisesRegex(
RuntimeError, r"The batch dimension for log_probs must be 1 at the current version"
):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)

targets = torch.randint(0, 4, (3, 4), device=self.device)
log_probs = torch.rand(3, 6, dtype=self.dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, r"2-D tensor is not yet supported for targets"):
log_probs = torch.rand(1, 3, 6, dtype=self.dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, r"The batch dimension for targets must be 1 at the current version"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)

targets = torch.tensor([0, 1, 2, 3], dtype=targets_dtype, device=self.device)
input_lengths = torch.randint(1, 5, (3,), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"input_lengths must be 0-D"):
targets = torch.tensor([[0, 1, 2, 3]], dtype=targets_dtype, device=self.device)
input_lengths = torch.randint(1, 5, (3, 5), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"input_lengths must be 1-D"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)

input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
target_lengths = torch.randint(1, 5, (3,), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"target_lengths must be 0-D"):
input_lengths = torch.tensor([log_probs.shape[0]], device=self.device)
target_lengths = torch.randint(1, 5, (3, 5), device=self.device)
with self.assertRaisesRegex(RuntimeError, r"target_lengths must be 1-D"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)

input_lengths = torch.tensor((10000), device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device)
input_lengths = torch.tensor([10000], device=self.device)
target_lengths = torch.tensor([targets.shape[1]], device=self.device)
with self.assertRaisesRegex(RuntimeError, r"input length mismatch"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)

input_lengths = torch.tensor((log_probs.shape[0]))
target_lengths = torch.tensor((10000))
input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor([10000], device=self.device)
with self.assertRaisesRegex(RuntimeError, r"target length mismatch"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)

targets = torch.tensor([7, 8, 9, 10], dtype=targets_dtype, device=self.device)
log_probs = torch.rand(10, 5, dtype=self.dtype, device=self.device)
targets = torch.tensor([[7, 8, 9, 10]], dtype=targets_dtype, device=self.device)
log_probs = torch.rand(1, 10, 5, dtype=self.dtype, device=self.device)
with self.assertRaisesRegex(ValueError, r"targets values must be less than the CTC dimension"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)

targets = torch.tensor([1, 3, 3], dtype=targets_dtype, device=self.device)
targets = torch.tensor([[1, 3, 3]], dtype=targets_dtype, device=self.device)
blank = 10000
with self.assertRaisesRegex(RuntimeError, r"blank must be within \[0, num classes\)"):
hyp_path, hyp_scores = F.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
Expand Down Expand Up @@ -1238,14 +1245,14 @@ class FunctionalCUDAOnly(TestBaseMixin):
@nested_params(
[torch.half, torch.float, torch.double],
[torch.int32, torch.int64],
[(50, 100), (100, 100)],
[(10,), (40,), (45,)],
[(1, 50, 100), (1, 100, 100)],
[(1, 10), (1, 40), (1, 45)],
)
def test_forced_align_same_result(self, log_probs_dtype, targets_dtype, log_probs_shape, targets_shape):
log_probs = torch.rand(log_probs_shape, dtype=log_probs_dtype, device=self.device)
targets = torch.randint(1, 100, targets_shape, dtype=targets_dtype, device=self.device)
input_lengths = torch.tensor((log_probs.shape[0]), device=self.device)
target_lengths = torch.tensor((targets.shape[0]), device=self.device)
input_lengths = torch.tensor([log_probs.shape[1]], device=self.device)
target_lengths = torch.tensor([targets.shape[1]], device=self.device)
log_probs_cuda = log_probs.cuda()
targets_cuda = targets.cuda()
input_lengths_cuda = input_lengths.cuda()
Expand Down
Loading

0 comments on commit 572e610

Please sign in to comment.