diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 935c4921..2b9547d1 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -369,6 +369,12 @@ defmodule Bumblebee.Text.Generation do
       if config.forced_token_ids do
         &forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids)
       end,
+      if config.allowed_token_ids != [] do
+        &allowed_tokens_processor(&1, &2, allowed_token_ids: config.allowed_token_ids)
+      end,
+      if config.dfa do
+        &dfa_processor(&1, &2, dfa: config.dfa)
+      end,
       if config.temperature && config.temperature != 1.0 do
         &temperature_processor(&1, &2, temperature: config.temperature)
       end
diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex
index eff38e52..7e3a8603 100644
--- a/lib/bumblebee/text/generation/logits_processing.ex
+++ b/lib/bumblebee/text/generation/logits_processing.ex
@@ -3,6 +3,91 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
 
   import Nx.Defn
 
+  deftransform dfa_processor(logits, context, opts \\ []) do
+    opts = Keyword.validate!(opts, [:dfa])
+    dfa = opts[:dfa]
+
+    # Count the distinct states on both sides of the transitions, so that
+    # terminal states (which never appear as a source) get a row as well.
+    # Note that Enum.dedup_by/2 would only drop consecutive duplicates.
+    num_states =
+      dfa.state_transitions
+      |> Enum.flat_map(fn {state, _token_id, next_state} -> [state, next_state] end)
+      |> Enum.uniq()
+      |> length()
+
+    # rows = states, columns = token ids, value = next state (0 = no transition)
+    state_transitions_tensor =
+      for {current_state, token_id, next_state} <- dfa.state_transitions,
+          reduce: Nx.broadcast(0, {num_states, Nx.size(logits)}) do
+        state_transitions_tensor ->
+          Nx.indexed_put(
+            state_transitions_tensor,
+            Nx.tensor([current_state, token_id]),
+            next_state
+          )
+      end
+
+    initial_state = Nx.tensor([0]) |> Nx.vectorize(batch: 1)
+
+    current_state =
+      find_current_state(
+        initial_state,
+        state_transitions_tensor,
+        context.sequence,
+        context.input_length,
+        context.length
+      )
+
+    # keep only the logits of tokens that have a transition out of the
+    # current state and suppress everything else
+    suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits))
+    Nx.select(state_transitions_tensor[current_state], logits, suppressed_logits)
+  end
+
+  defn find_current_state(
+         initial_state,
+         state_transitions_tensor,
+         sequence,
+         input_length,
+         current_length
+       ) do
+    generated_length = current_length - input_length
+
+    last_token_id = sequence[current_length - 1]
+    token_column = state_transitions_tensor[[.., last_token_id]] |> Nx.squeeze()
+
+    # top_k gives the two largest values (and their indices) of the column.
+    # If the token is unambiguous, there is exactly one value != 0 in the
+    # column (that is top_values[0]); if top_values[1] != 0, there are at
+    # least two values != 0 in the column, so the token is ambiguous.
+    {top_values, _top_indices} = Nx.top_k(token_column, k: 2)
+
+    ambiguous_token? = Nx.not_equal(top_values[1], Nx.tensor(0))
+
+    state =
+      cond do
+        generated_length == 0 ->
+          initial_state
+
+        ambiguous_token? ->
+          # the last token alone does not determine the state, so replay
+          # all generated tokens from the initial state
+          {state, _i, _sequence, _input_length, _generated_length, _state_transitions_tensor} =
+            while {state = initial_state, i = 0, sequence, input_length, generated_length,
+                   state_transitions_tensor},
+                  Nx.less(i, generated_length) do
+              chosen_token = sequence[input_length + i]
+              new_state = state_transitions_tensor[[state, chosen_token]]
+
+              {new_state, i + 1, sequence, input_length, generated_length,
+               state_transitions_tensor}
+            end
+
+          state
+
+        true ->
+          # top_values[0] is the only entry != 0 in the column, i.e. the
+          # target of the single transition labelled with this token,
+          # which is our new state
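+          # (e.g. for the column [0, 3, 0] the automaton can only have been
+          # in state 1 and must now be in state 3, so no replay is needed)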
+          top_values[0]
+      end
+
+    state
+  end
+
   deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do
     opts = Keyword.validate!(opts, [:suppressed_token_ids])
 
@@ -11,6 +96,12 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
     Nx.indexed_put(logits, indices, values)
   end
 
+  deftransform allowed_tokens_processor(logits, _context, opts \\ []) do
+    opts = Keyword.validate!(opts, [:allowed_token_ids])
+
+    allow_token_ids(logits, opts[:allowed_token_ids])
+  end
+
   defn bos_token_processor(logits, context, opts \\ []) do
     opts = keyword!(opts, [:bos_token_id])
     bos_token_id = opts[:bos_token_id]
@@ -113,6 +204,16 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
     |> Nx.put_slice([token_id], Nx.tensor([0], type: Nx.type(logits)))
   end
 
+  deftransformp allow_token_ids(logits, allowed_token_ids) do
+    # allowed_token_ids arrives as a plain list, so build an index tensor
+    allowed_indices = Nx.tensor(allowed_token_ids)
+
+    # keep the logits at the allowed indices and set every other entry to -inf
+    allowed_logits = Nx.take(logits, allowed_indices)
+    suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits))
+
+    indices = Nx.new_axis(allowed_indices, -1)
+    Nx.indexed_put(suppressed_logits, indices, allowed_logits)
+  end
+
   deftransformp ignore_token_id(logits, token_id) do
     Nx.put_slice(
       logits,
diff --git a/lib/bumblebee/text/generation_config.ex b/lib/bumblebee/text/generation_config.ex
index d7a6a9a0..20cfd14d 100644
--- a/lib/bumblebee/text/generation_config.ex
+++ b/lib/bumblebee/text/generation_config.ex
@@ -93,6 +93,14 @@ defmodule Bumblebee.Text.GenerationConfig do
       default: [],
       doc: "a list of token ids to suppress during generation"
     ],
+    allowed_token_ids: [
+      default: [],
+      doc: "a list of token ids to allow during generation (all tokens not in the list are suppressed)"
+    ],
+    dfa: [
+      default: nil,
+      doc:
+        "the definition of a simple deterministic finite automaton (DFA) constraining generation, given as a map with a :state_transitions list of {current_state, token_id, next_state} tuples"
+    ],
     no_repeat_ngram_length: [
       default: nil,
       doc: "when set, n-grams of the given length can occur only once in the generated sequence"
diff --git a/pair_programming.exs b/pair_programming.exs
new file mode 100644
index 00000000..ace3c12a
--- /dev/null
+++ b/pair_programming.exs
@@ -0,0 +1,127 @@
+Mix.install([
+  {:bumblebee, path: "../bumblebee_bitcrowd"},
+  {:nx, "~> 0.10.0", override: true},
+  {:emlx, github: "elixir-nx/emlx"},
+  {:benchee, "~> 1.0"}
+])
+
+Nx.global_default_backend({EMLX.Backend, device: :gpu})
+
+repo = {:hf, "HuggingFaceTB/SmolLM2-135M-Instruct"}
+{:ok, model_info} = Bumblebee.load_model(repo, backend: {EMLX.Backend, device: :gpu})
+{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
+{:ok, generation_config} = Bumblebee.load_generation_config(repo)
+
+sequence_length = 512
+
+prompt = """
+Give me an array that contains a mix of numbers and text.
+There MUST be at least one number and one text.
+Valid examples are:
+
+["hello",89,"hola",6,4,8]
+"""
+
+numbers = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]
+array_start_token = "["
+array_end_token = "]"
+array_addition_token = ","
+# Arbitrary string contents require a negated set: every token in the
+# vocabulary except the quote token itself (id 18 in this tokenizer) and
+# the special tokens, see string_token_ids below
+string_token = "\""
+
+# states are identified by their index in this list
+states = [
+  :starting,
+  :in_array,
+  :in_number,
+  :in_addition,
+  :in_string,
+  :end_of_string,
+  :ending,
+  :done
+]
+
+state_to_num = fn state -> Enum.find_index(states, &(&1 == state)) end
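+
+# quick sanity check of the state indexing (:in_array sits at index 1)
+1 = state_to_num.(:in_array)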
+
+# ------------------------------------- above chars ------------------------------ #
+# ------------------------------------- below tokens ----------------------------- #
+
+array_start_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, array_start_token)
+array_end_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, array_end_token)
+addition_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, array_addition_token)
+string_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, string_token)
+end_of_sequence_token_id = Bumblebee.Tokenizer.special_token_id(tokenizer, :eos)
+
+# token ids 0..17 are the special tokens of this tokenizer
+special_tokens_ids = Enum.to_list(0..17)
+number_tokens_ids = Enum.map(numbers, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
+# valid token ids run from 0 to vocab_size - 1
+vocabulary_token_ids = Enum.to_list(0..(model_info.spec.vocab_size - 1))
+
+string_token_ids = vocabulary_token_ids -- ([string_token_id] ++ special_tokens_ids)
+
+## Example walk for the sequence 75, 33, 34, ...:
+# state            0   1   2
+# chosen token id  75  33  34
+# new state        1   2   2
+
+## The transitions tensor maps state/token id -> new state, with 0 meaning
+## "no transition" (that encoding is safe because no transition ever leads
+## back to the :starting state 0):
+## State / Token ids   0  1  2 ... 18 ... 33 ... 75 76
+## starting (0)        0  0  0      0      0      1  0
+## in_array (1)                     4      2         6
+## in_number (2)
+## in_addition (3)
+## in_string (4)
+## end_of_string (5)
+## ending (6)
+## done (7)
+
+## which tokens lead to which state from a given state
+state_transitions =
+  [
+    # starting
+    {:starting, [array_start_token_id], :in_array},
+    # in_array
+    {:in_array, number_tokens_ids, :in_number},
+    {:in_array, [array_end_token_id], :ending},
+    {:in_array, [string_token_id], :in_string},
+    # in_number
+    {:in_number, number_tokens_ids, :in_number},
+    {:in_number, [addition_token_id], :in_addition},
+    {:in_number, [array_end_token_id], :ending},
+    # in_addition
+    {:in_addition, number_tokens_ids, :in_number},
+    {:in_addition, [string_token_id], :in_string},
+    # in_string
+    {:in_string, string_token_ids, :in_string},
+    {:in_string, [string_token_id], :end_of_string},
+    # end_of_string
+    {:end_of_string, [addition_token_id], :in_addition},
+    {:end_of_string, [array_end_token_id], :ending},
+    # ending
+    {:ending, [end_of_sequence_token_id], :done}
+  ]
+  |> Enum.flat_map(fn {current_state, token_ids, next_state} ->
+    for token_id <- token_ids do
+      {state_to_num.(current_state), token_id, state_to_num.(next_state)}
+    end
+  end)
+
+dfa = %{state_transitions: state_transitions}
+
+generation_config =
+  Bumblebee.configure(generation_config,
+    max_new_tokens: 24,
+    strategy: %{type: :multinomial_sampling, top_p: 0.6},
+    dfa: dfa
+  )
+
+serving =
+  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
+    compile: [batch_size: 1, sequence_length: sequence_length],
+    stream: false,
+    defn_options: [compiler: Nx.Defn.Evaluator]
+  )
+
+%{results: [result]} = Nx.Serving.run(serving, prompt) |> dbg
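+
+# With the processor active the completion is expected to satisfy the DFA,
+# e.g. something of the shape ["hi",42,"ok"]. Hypothetical sanity check
+# (result entries of the generation serving carry a :text field):
+true = String.starts_with?(result.text, "[")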
diff --git a/test/bumblebee/text/generation/logits_processing_test.exs b/test/bumblebee/text/generation/logits_processing_test.exs
index 5bc5a44f..6c84a8f8 100644
--- a/test/bumblebee/text/generation/logits_processing_test.exs
+++ b/test/bumblebee/text/generation/logits_processing_test.exs
@@ -5,6 +5,24 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
 
   alias Bumblebee.Text.Generation.LogitsProcessing
 
+  describe "find_current_state" do
+    test "finds ambiguous token ids" do
+      ## rows = states / columns = token ids -> value = new state
+      ##   0 1 2 3
+      ##   1 3 0 0
+      ##   2 2 0 0
+      ## column 1 holds more than one value != 0, so token 1 is ambiguous
+      ambiguous_token_id = 1
+      state_transitions_tensor = Nx.tensor([[0, 1, 2, 3], [1, 3, 0, 0], [2, 2, 0, 0]])
+      token_column = state_transitions_tensor[[.., ambiguous_token_id]] |> Nx.squeeze()
+      {top_values, _top_indices} = Nx.top_k(token_column, k: 2)
+
+      # a tensor struct is always truthy, so assert on the extracted number
+      assert Nx.to_number(top_values[1]) != 0
+    end
+  end
+
   describe "suppressed_tokens_processor/3" do
     test "ignores the given tokens" do
       logits = Nx.tensor([1.0, 2.0, 3.0, 4.0])