diff --git a/k2/csrc/intersect_dense.cu b/k2/csrc/intersect_dense.cu index fbb53ca54..7291a3a3f 100644 --- a/k2/csrc/intersect_dense.cu +++ b/k2/csrc/intersect_dense.cu @@ -778,7 +778,7 @@ class MultiGraphDenseIntersect { void DoStep(int32_t t) { NVTX_RANGE(K2_FUNC); Step &step = steps_[t], &prev_step = steps_[t - 1]; - int32_t scores_num_cols = b_fsas_.scores.Dim1(); + int32_t scores_num_cols = b_fsas_.scores.Dim1(); const float minus_inf = -std::numeric_limits::infinity(); // Divide by two because each arc is repeated twice in arc_scores (once for @@ -814,9 +814,10 @@ class MultiGraphDenseIntersect { backward_dest_prob = prev_state_scores_data[dest_state_scores_index_backward]; - // Assign negative infinity (-inf) to both the forward and backward scores, - // if the label on the carc is out-of-range, i.e., the label in the decoding - // graph (a_fsas) does not exist in the neural-net output (b_fsas). + // Assign negative infinity (-inf) to both the forward and backward + // scores, if the label on the carc is out-of-range, i.e., the label + // in the decoding graph (a_fsas) does not exist in the neural-net + // output (b_fsas). float b_score_forward; float b_score_backward; if (carc.label_plus_one <= scores_num_cols) { diff --git a/k2/csrc/rnnt_decode.cu b/k2/csrc/rnnt_decode.cu index 7e44391bd..4c3b18681 100644 --- a/k2/csrc/rnnt_decode.cu +++ b/k2/csrc/rnnt_decode.cu @@ -248,7 +248,7 @@ RaggedShape RnntDecodingStreams::ExpandArcs() { return unpruned_arcs_shape; } -Renumbering RnntDecodingStreams::DoFisrtPassPruning( +Renumbering RnntDecodingStreams::DoFirstPassPruning( RaggedShape &unpruned_arcs_shape, const Array2 &logprobs) { NVTX_RANGE(K2_FUNC); K2_CHECK_EQ(unpruned_arcs_shape.NumAxes(), 4); @@ -439,7 +439,7 @@ void RnntDecodingStreams::Advance(const Array2 &logprobs) { auto unpruned_arcs_shape = ExpandArcs(); // (2) Do initial pruning. 
- auto pass1_renumbering = DoFisrtPassPruning(unpruned_arcs_shape, logprobs); + auto pass1_renumbering = DoFirstPassPruning(unpruned_arcs_shape, logprobs); // pass1_arcs_shape has a shape of [stream][context][state][arc] auto pass1_arcs_shape = @@ -489,6 +489,7 @@ void RnntDecodingStreams::Advance(const Array2 &logprobs) { const auto logprobs_acc = logprobs.Accessor(); const Arc *const *graphs_arcs_data = graphs_.values.Data(); + K2_EVAL( c_, cur_num_arcs, lambda_populate_arcs_states_scores, (int32_t arc_idx) { // Init renumber_arcs to 0, place here to save one kernel. @@ -508,6 +509,7 @@ void RnntDecodingStreams::Advance(const Array2 &logprobs) { idx01 = uas_row_ids2_data[idx012], idx0 = uas_row_ids1_data[idx01], num_graph_states = num_graph_states_data[idx0]; int64_t this_state = this_states_values_data[idx012]; + int32_t this_graph_state = this_state % num_graph_states; double this_score = this_scores_data[idx012]; // handle the implicit epsilon self-loop @@ -516,7 +518,21 @@ void RnntDecodingStreams::Advance(const Array2 &logprobs) { // we assume termination symbol to be 0 here. scores_data[arc_idx] = this_score + logprobs_acc(idx01, 0); ArcInfo ai; - ai.graph_arc_idx01 = -1; + /* + Track state index for self-loop arcs. + It's lucky that type int32_t has range [-2147483648, 2147483647] + there is one more negative values than positive values in computer. + state (0) --> graph_arc_idx01 (-1) + state (1) --> graph_arc_idx01 (-2) + state (2) --> graph_arc_idx01 (-3) + state (2147483647) --> graph_arc_idx01 (-2147483648) + + Actually, super final state has no self-loop. + So definitely there are enough negative values + to represent positive state index. 
+ */ + ai.graph_arc_idx01 = -(this_graph_state + 1); + K2_CHECK_LT(ai.graph_arc_idx01, 0); ai.score = logprobs_acc(idx01, 0); ai.label = 0; arcs_data[arc_idx] = ai; @@ -527,8 +543,7 @@ void RnntDecodingStreams::Advance(const Array2 &logprobs) { const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0]; int64_t this_context_state = this_state / num_graph_states; - int32_t this_graph_state = this_state % num_graph_states, - graph_idx0x = graph_row_split1_data[this_graph_state], + int32_t graph_idx0x = graph_row_split1_data[this_graph_state], graph_idx01 = graph_idx0x + idx3 - 1; // minus 1 here as // epsilon self-loop // takes the position 0. @@ -715,6 +730,162 @@ void RnntDecodingStreams::GatherPrevFrames( } } +void RnntDecodingStreams::GetFinalArcs() { + NVTX_RANGE(K2_FUNC); + /* + This function handles the last two steps of the generated lattice. + The relationship of variables in these two steps is: + + arcs: last frame arcs final arcs + states: {last frame state} ---------------> {final states} ---------> {super final state} # noqa + + Super final state has no leaving arcs. + */ + + int32_t frames = prev_frames_.size(); + + // with shape [stream][context][state][arc] + auto last_frame_shape = prev_frames_[frames - 1]->shape; + + // Note: last_frame_arc_data is non-const + // The original "dest_state" attribute for each element in last_frame_arc_data + // is state index processed by function GroupStatesByContexts. + // In this function, source states in last_frame are expanded again, + // and those expanded destination states are NOT grouped to save time. + // So "dest_state" should be re-assigned to a new value. 
+ ArcInfo *last_frame_arc_data = prev_frames_[frames - 1]->values.Data(); + const int32_t *lfs_row_ids3_data = last_frame_shape.RowIds(3).Data(), + *lfs_row_ids2_data = last_frame_shape.RowIds(2).Data(), + *lfs_row_ids1_data = last_frame_shape.RowIds(1).Data(), + *lfs_row_splits3_data = last_frame_shape.RowSplits(3).Data(), + *lfs_row_splits2_data = last_frame_shape.RowSplits(2).Data(), + *lfs_row_splits1_data = last_frame_shape.RowSplits(1).Data(); + + const int32_t *num_graph_states_data = num_graph_states_.Data(); + const int32_t *const *graph_row_splits1_ptr_data = graphs_.shape.RowSplits(1); + const Arc *const *graphs_arcs_data = graphs_.values.Data(); + + // Name meaning of final_graph_states: + // "final_" means it's for "final states". + // "_graph_states" means it stores the state index in the decoding graph. + // Though this variable could be calculated both in + // lambda_get_final_arcs_shape and lambda_populate_final_arcs, + // to save time, it's calculated and cached during the former and + // used in the latter. + Array1 final_graph_states(c_, last_frame_shape.NumElements()); + int32_t* final_graph_states_data = final_graph_states.Data(); + + // Calculate num_arcs for each final state. + Array1 num_final_arcs(c_, last_frame_shape.NumElements() + 1); + int32_t *num_final_arcs_data = num_final_arcs.Data(); + + K2_EVAL( + c_, last_frame_shape.NumElements(), lambda_get_final_arcs_shape, + (int32_t idx0123) { + // place here to save one kernel. + num_final_arcs_data[idx0123] = 0; + + int32_t idx012 = lfs_row_ids3_data[idx0123], // state_idx012 + idx01 = lfs_row_ids2_data[idx012], // context_idx01 + idx0 = lfs_row_ids1_data[idx01], // stream_idx0 + arc_idx01x = lfs_row_splits2_data[idx01], + arc_idx01xx = lfs_row_splits3_data[arc_idx01x], + arc_idx23 = idx0123 - arc_idx01xx; + + ArcInfo& ai = last_frame_arc_data[idx0123]; + + // Re-assign dest_state to a new value. + // See more detail comment at previous last_frame_arc_data definition. 
+ ai.dest_state = arc_idx23; + + if (ai.label == -1) { + // -(num_graph_states_data[idx0]) for state not expandable. + final_graph_states_data[idx0123] = -(num_graph_states_data[idx0]); + return; + } + int32_t dest_state = -1; + const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0]; + const Arc *graph_arcs_data = graphs_arcs_data[idx0]; + if (ai.graph_arc_idx01 < 0) { + // For implicit self-loop arcs. + dest_state = -ai.graph_arc_idx01 - 1; + K2_CHECK_GE(dest_state, 0); + K2_CHECK_LE(dest_state, num_graph_states_data[idx0]); + } else { + // For other arcs shown in the decoding graph. + dest_state = graph_arcs_data[ai.graph_arc_idx01].dest_state; + } + K2_CHECK_GE(dest_state, 0); + + final_graph_states_data[idx0123] = dest_state; + // Plus one for the implicit epsilon self-loop. + num_final_arcs_data[idx0123] = graph_row_split1_data[dest_state + 1] - + graph_row_split1_data[dest_state] + 1; + }); + + + ExclusiveSum(num_final_arcs, &num_final_arcs); + + auto final_arcs_shape = RaggedShape2(&num_final_arcs, nullptr, -1); + final_arcs_shape = ComposeRaggedShapes(last_frame_shape, final_arcs_shape); + // [stream][context][state][arc][arc] --> [stream][context][arc][arc] + // could be viewed as [stream][context][final state][arc] + final_arcs_shape = RemoveAxis(final_arcs_shape, 2); + const int32_t *fas_row_ids1_data = final_arcs_shape.RowIds(1).Data(), + *fas_row_ids2_data = final_arcs_shape.RowIds(2).Data(), + *fas_row_ids3_data = final_arcs_shape.RowIds(3).Data(), + *fas_row_splits3_data = final_arcs_shape.RowSplits(3).Data(); + + auto final_arcs = Ragged(final_arcs_shape); + ArcInfo *final_arcs_data = final_arcs.values.Data(); + + K2_EVAL( + c_, final_arcs_shape.NumElements(), lambda_populate_final_arcs, + (int32_t idx0123) { + const int32_t idx012 = fas_row_ids3_data[idx0123], // state + idx01 = fas_row_ids2_data[idx012], // context + idx0 = fas_row_ids1_data[idx01], // stream + idx012x = fas_row_splits3_data[idx012], + arc_idx3 = idx0123 - idx012x; + 
+ const Arc *graph_arcs_data = graphs_arcs_data[idx0]; + const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0]; + int32_t graph_state_idx0 = final_graph_states_data[idx012]; + + int32_t ai_graph_arc_idx01 = 0; + int32_t ai_arc_label = 0; + if (graph_state_idx0 < 0) { + /* + Could be one of following two cases: + case 1: not expandable if graph_state_idx0 == -(num_graph_states_data[idx0]) # noqa + case 2: implicit self-loop if graph_state_idx0 > -(num_graph_states_data[idx0]) # noqa + */ + K2_DCHECK_GT(graph_state_idx0, -(num_graph_states_data[idx0])); + ai_arc_label = 0; + ai_graph_arc_idx01 = -1; + } else { + // For arcs shown in decoding graph. + int32_t graph_arc_idx0x = graph_row_split1_data[graph_state_idx0]; + // arc_idx3 could be viewed as graph_arc_idx1, + // since final_arcs_shape has 4 axes where arc_idx3 is calculated, + // while decoding_graph only has 2 axes where arc_idx3 is used. + ai_graph_arc_idx01 = graph_arc_idx0x + arc_idx3; + auto graph_arc = graph_arcs_data[ai_graph_arc_idx01]; + ai_arc_label = graph_arc.label; + } + ArcInfo ai; + // ai.dest_state will be overwritten by FormatOutput + // just initialize it as -1 here + ai.dest_state = -1; + ai.graph_arc_idx01 = ai_graph_arc_idx01; + ai.score = 0.0; + ai.label = ai_arc_label; + final_arcs_data[idx0123] = ai; + }); + + prev_frames_.emplace_back(std::make_shared>(final_arcs)); +} + void RnntDecodingStreams::FormatOutput(const std::vector &num_frames, bool allow_partial, FsaVec *ofsa, Array1 *out_map) { @@ -736,6 +907,8 @@ void RnntDecodingStreams::FormatOutput(const std::vector &num_frames, GatherPrevFrames(num_frames); + GetFinalArcs(); + int32_t frames = prev_frames_.size(); auto last_frame_shape = prev_frames_[frames - 1]->shape; @@ -888,7 +1061,7 @@ void RnntDecodingStreams::FormatOutput(const std::vector &num_frames, K2_EVAL( c_, num_streams_, lambda_set_start_offset, (int32_t stream_idx) { num_padded_frames_data[stream_idx] = - frames - num_padded_frames_data[stream_idx]; 
+ frames - num_padded_frames_data[stream_idx] - 1; K2_CHECK_LE(0, num_padded_frames_data[stream_idx]); }); } @@ -946,17 +1119,18 @@ void RnntDecodingStreams::FormatOutput(const std::vector &num_frames, int32_t dest_state_idx012 = oarc_idx01xx_next + arc_info.dest_state; arc.dest_state = dest_state_idx012 - oarc_idx0xxx; - // graph_arc_idx01 == -1 means this is a implicit epsilon self-loop + // graph_arc_idx01 < 0 means this is an implicit epsilon self-loop // arc_info.label == -1 means this is the final arc before last // frame this is non-accessible arc, we set its label to 0 here to // make the generated lattice a valid k2 fsa. - if (arc_info.graph_arc_idx01 == -1 || arc_info.label == -1) { + if (arc_info.graph_arc_idx01 <= -1 || arc_info.label == -1) { arc.label = 0; + out_map_data[oarc_idx01234] = -1; } else { arc.label = graph_arcs_data[arc_info.graph_arc_idx01].label; + out_map_data[oarc_idx01234] = arc_info.graph_arc_idx01; } arc.score = arc_info.score; - out_map_data[oarc_idx01234] = arc_info.graph_arc_idx01; } arcs_out_data[oarc_idx01234] = arc; if (arc_map_b != nullptr) { diff --git a/k2/csrc/rnnt_decode.h b/k2/csrc/rnnt_decode.h index 3e2cd0fcd..c34f90cd3 100644 --- a/k2/csrc/rnnt_decode.h +++ b/k2/csrc/rnnt_decode.h @@ -94,10 +94,25 @@ struct RnntDecodingConfig { struct ArcInfo { // The arc-index within the RnntDecodingStream::graph that corresponds to this - // arc, or -1 if this arc is a "termination symbol" (these do not appear in - // the graph). + // arc if non-negative. + // There is an implicit self-loop arc for each state, which is represented + // by -(state_index + 1), see following comments of dest_state_in_graph. int32_t graph_arc_idx01; + // Note: + // 1. To save memory, value of this variable is calculated + // from graph_arc_idx01. + // 2. It is different from variable dest_state. + // dest_state_in_graph is the destination state index in decoding graph. + // dest_state below is the state index in "generated lattice". 
+ // There are two kinds of arcs in decoding graph: + // 1. Implicit self-loop arcs, dest_state of these arcs are calculated + // with -(graph_arc_idx01 + 1). + // (Note, graph_arc_idx01 is negative for these arcs) + // 2. Other arcs shown in decoding graph, dest_state of these arcs are + // calculated with graph_arcs_data[ai.graph_arc_idx01].dest_state + // int32_t dest_state_in_graph; + // The score on the arc; contains both the graph score (if any) and the score // from the RNN-T joiner. float score; @@ -220,6 +235,38 @@ class RnntDecodingStreams { void FormatOutput(const std::vector &num_frames, bool allow_partial, FsaVec *ofsa, Array1 *out_map); + /* + Generate the lattice. + Note: Almost the same with previous overloaded version, + except for an extra `is_final` argument. + + Note: The prev_frames_ only contains decoded by current object, in order to + generate the lattice we will first gather all the previous frames from + individual streams. + + @param [in] num_frames A vector containing the number of frames we want + to gather for each stream (note: the frames we have + ever received). + It MUST satisfy `num_frames.size() == num_streams_`, and + `num_frames[i] <= srcs_[i].prev_frames.size()`. + @param [in] allow_partial If true and there is no final state active, + we will treat all the states on the last frame + to be final state. If false, we only + care about the real final state in the decoding + graph on the last frame when generating lattice. + @param [in] is_final If true, function GetFinalArcs() will be called. + If false, the same with previous overloaded version. + @param [out] ofsa The output lattice will write to here, its num_axes + equals to 3, will be re-allocated. + @param [out] out_map It is an Array1 with Dim() equals to + ofsa.NumElements() containing the idx01 into the graph of + each individual streams, mapping current arc in ofsa to + original decoding graphs. It may contain -1 which means + this arc is a "termination symbol". 
+ */ + void FormatOutput(const std::vector &num_frames, bool allow_partial, + bool is_final, FsaVec *ofsa, Array1 *out_map); + /* Terminate the decoding process of current RnntDecodingStreams object, it will update the states & scores of each individual stream and split & @@ -282,8 +329,30 @@ class RnntDecodingStreams { @return Return the renumbering object indicating which arc will be kept. */ - Renumbering DoFisrtPassPruning(RaggedShape &unprund_arcs_shape, + Renumbering DoFirstPassPruning(RaggedShape &unprund_arcs_shape, const Array2 &logprobs); + + /* + Get final arcs when last frame is received, i.e. passing is_final=True to + function `FormatOutput`. + Comparing with openfst, a valid fsa in k2 needs arcs with label==-1 + pointing to a super final state. This function is handling these arcs. + See detail of the problem solved by this function at + https://github.com/k2-fsa/k2/pull/1089 + + If we name variables for last two steps of a lattice as: + arcs: last frame arcs final arcs + states: {last frame state} ---------------> {final states} ---------> {super final state} + + This function mainly does the following steps: + 1. get last_frame from prev_frames_ + 2. expand last frame and get final states + 3. re-assign dest state of last frame arcs to final states + 4. populate final arcs + 5. append final arcs to prev_frames_ + */ + void GetFinalArcs(); + /* Group states by contexts.