From fc8c8ffcebd75894bb935c595aaea6e9ce68645c Mon Sep 17 00:00:00 2001
From: Chris
Date: Mon, 13 Oct 2025 14:10:01 +0200
Subject: [PATCH 01/18] array of numbers

---
 lib/bumblebee/text/generation.ex              |  3 ++
 .../text/generation/logits_processing.ex      | 16 ++++++
 lib/bumblebee/text/generation_config.ex       |  4 ++
 pair_programming.exs                          | 49 +++++++++++++++++++
 4 files changed, 72 insertions(+)
 create mode 100644 pair_programming.exs

diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 935c4921..666fca78 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -369,6 +369,9 @@ defmodule Bumblebee.Text.Generation do
       if config.forced_token_ids do
         &forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids)
       end,
+      if config.allowed_token_ids != [] do
+        &allowed_tokens_processor(&1, &2, allowed_token_ids: config.allowed_token_ids)
+      end,
       if config.temperature && config.temperature != 1.0 do
         &temperature_processor(&1, &2, temperature: config.temperature)
       end
diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex
index eff38e52..d261ebf2 100644
--- a/lib/bumblebee/text/generation/logits_processing.ex
+++ b/lib/bumblebee/text/generation/logits_processing.ex
@@ -11,6 +11,12 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
     Nx.indexed_put(logits, indices, values)
   end
 
+  deftransform allowed_tokens_processor(logits, _context, opts \\ []) do
+    opts = Keyword.validate!(opts, [:allowed_token_ids])
+
+    allow_token_ids(logits, opts[:allowed_token_ids])
+  end
+
   defn bos_token_processor(logits, context, opts \\ []) do
     opts = keyword!(opts, [:bos_token_id])
     bos_token_id = opts[:bos_token_id]
@@ -113,6 +119,16 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
     |> Nx.put_slice([token_id], Nx.tensor([0], type: Nx.type(logits)))
   end
 
+  deftransformp allow_token_ids(logits, allowed_token_ids) do
+    # Convert the allowed_token_ids list to a tensor of indices
+    allowed_indices = Nx.tensor(allowed_token_ids)
+    allowed_logits = Nx.take(logits, allowed_indices)
+    suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits))
+
+    indices = Nx.new_axis(allowed_indices, -1)
+    Nx.indexed_put(suppressed_logits, indices, allowed_logits)
+  end
+
   deftransformp ignore_token_id(logits, token_id) do
     Nx.put_slice(
       logits,
diff --git a/lib/bumblebee/text/generation_config.ex b/lib/bumblebee/text/generation_config.ex
index d7a6a9a0..2070c2e5 100644
--- a/lib/bumblebee/text/generation_config.ex
+++ b/lib/bumblebee/text/generation_config.ex
@@ -93,6 +93,10 @@ defmodule Bumblebee.Text.GenerationConfig do
       default: [],
       doc: "a list of token ids to suppress during generation"
     ],
+    allowed_token_ids: [
+      default: [],
+      doc: "a list of token ids to enforce during generation (suppressing all tokens that are not in the list)"
+    ],
     no_repeat_ngram_length: [
       default: nil,
       doc: "when set, n-grams of the given length can occur only once in the generated sequence"
     ],
diff --git a/pair_programming.exs b/pair_programming.exs
new file mode 100644
index 00000000..3a6820fe
--- /dev/null
+++ b/pair_programming.exs
@@ -0,0 +1,49 @@
+Mix.install([
+  {:bumblebee, path: "../bumblebee_bitcrowd"},
+  {:nx, "~> 0.10.0", override: true},
+  {:emlx, github: "elixir-nx/emlx"},
+])
+
+Nx.global_default_backend({EMLX.Backend, device: :gpu})
+repo = {:hf, "HuggingFaceTB/SmolLM2-135M-Instruct"}
+{:ok, model_info} = Bumblebee.load_model(repo, backend: {EMLX.Backend, device: :gpu})
+{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
+{:ok, generation_config} = Bumblebee.load_generation_config(repo)
+
+sequence_length = 512
+prompt = """
+  Give me 10 random, single digit numbers in an array.
+  Valid examples are:
+
+  [1]
+  [4,7]
+  [2,4,1]
+  [4,5,3,6]
+  [1,7,8,0]
+  [9,4,7,3,5,2]
+  [8,2,3,8,6,4,8]
+  [3,5,9]
+  [8,9,6,7]
+  """
+
+generation_config =
+  Bumblebee.configure(generation_config,
+    max_new_tokens: 24,
+    # allowed_token_ids: [0,1,2,28,32,33,34,35,36,37,38,39,40,41,75,77],
+    strategy: %{type: :multinomial_sampling, top_p: 0.6}
+  )
+
+
+serving =
+  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
+    compile: [batch_size: 1, sequence_length: sequence_length],
+    stream: false,
+    defn_options: [compiler: Nx.Defn.Evaluator]
+  )
+
+
+
+{:ok, _pid} =
+  Supervisor.start_link([{Nx.Serving, name: Serving, serving: serving}], strategy: :one_for_one)
+
+Nx.Serving.run(serving, prompt) |> dbg

From 141631b9b895559d227d6618b2dc6e1c09e284b4 Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Mon, 13 Oct 2025 14:31:13 +0200
Subject: [PATCH 02/18] mix format

---
 pair_programming.exs | 32 +++++++++++++++-----------------
 1 file changed, 15 insertions(+), 17 deletions(-)

diff --git a/pair_programming.exs b/pair_programming.exs
index 3a6820fe..494640dd 100644
--- a/pair_programming.exs
+++ b/pair_programming.exs
@@ -1,7 +1,7 @@
 Mix.install([
   {:bumblebee, path: "../bumblebee_bitcrowd"},
   {:nx, "~> 0.10.0", override: true},
-  {:emlx, github: "elixir-nx/emlx"},
+  {:emlx, github: "elixir-nx/emlx"}
 ])
 
 Nx.global_default_backend({EMLX.Backend, device: :gpu})
@@ -11,20 +11,21 @@ repo = {:hf, "HuggingFaceTB/SmolLM2-135M-Instruct"}
 {:ok, generation_config} = Bumblebee.load_generation_config(repo)
 
 sequence_length = 512
+
 prompt = """
-  Give me 10 random, single digit numbers in an array.
-  Valid examples are:
-
-  [1]
-  [4,7]
-  [2,4,1]
-  [4,5,3,6]
-  [1,7,8,0]
-  [9,4,7,3,5,2]
-  [8,2,3,8,6,4,8]
-  [3,5,9]
-  [8,9,6,7]
-  """
+Give me 10 random, single digit numbers in an array.
+Valid examples are:
+
+[1]
+[4,7]
+[2,4,1]
+[4,5,3,6]
+[1,7,8,0]
+[9,4,7,3,5,2]
+[8,2,3,8,6,4,8]
+[3,5,9]
+[8,9,6,7]
+"""
 
 generation_config =
   Bumblebee.configure(generation_config,
     strategy: %{type: :multinomial_sampling, top_p: 0.6}
   )
 
-
 serving =
   Bumblebee.Text.generation(model_info, tokenizer, generation_config,
     compile: [batch_size: 1, sequence_length: sequence_length],
     stream: false,
     defn_options: [compiler: Nx.Defn.Evaluator]
   )
 
-
-
 {:ok, _pid} =
   Supervisor.start_link([{Nx.Serving, name: Serving, serving: serving}], strategy: :one_for_one)
 
 Nx.Serving.run(serving, prompt) |> dbg

From 494c7b92f28530403b358167be10d1c9ae978035 Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Mon, 13 Oct 2025 14:31:50 +0200
Subject: [PATCH 03/18] find allowed token_ids in vocabulary

---
 pair_programming.exs | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/pair_programming.exs b/pair_programming.exs
index 494640dd..2d8f3d8c 100644
--- a/pair_programming.exs
+++ b/pair_programming.exs
@@ -27,10 +27,21 @@ Valid examples are:
 [8,9,6,7]
 """
 
+allowed_tokens = ["[", "1", "2", ",", "]"]
+
+special_token_ids =
+  Bumblebee.Tokenizer.all_special_tokens(tokenizer)
+  |> Enum.map(&Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
+  |> Enum.reject(&is_nil/1)
+
+allowed_token_ids = Enum.map(allowed_tokens, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
+
+all_allowed_token_ids = special_token_ids ++ allowed_token_ids
+
 generation_config =
   Bumblebee.configure(generation_config,
     max_new_tokens: 24,
-    # allowed_token_ids: [0,1,2,28,32,33,34,35,36,37,38,39,40,41,75,77],
+    allowed_token_ids: all_allowed_token_ids,
     strategy: %{type: :multinomial_sampling, top_p: 0.6}
   )

From 6b36285d686e725625c0b8a1a1376108d37ce857 Mon Sep 17 00:00:00 2001
From: Chris
Date: Mon, 13 Oct 2025 16:59:10 +0200
Subject: [PATCH 04/18] oh so very wip

---
 lib/bumblebee/text/generation.ex              |  3 +
 .../text/generation/logits_processing.ex      | 38 ++++++++++++
 lib/bumblebee/text/generation_config.ex       |  4 ++
 pair_programming.exs                          | 61 ++++++++++++++++---
 4 files changed, 97 insertions(+), 9 deletions(-)

diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 666fca78..2b9547d1 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -372,6 +372,9 @@ defmodule Bumblebee.Text.Generation do
       if config.allowed_token_ids != [] do
         &allowed_tokens_processor(&1, &2, allowed_token_ids: config.allowed_token_ids)
       end,
+      if config.dfa do
+        &dfa_processor(&1, &2, dfa: config.dfa)
+      end,
       if config.temperature && config.temperature != 1.0 do
         &temperature_processor(&1, &2, temperature: config.temperature)
       end
diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex
index d261ebf2..7b1889b0 100644
--- a/lib/bumblebee/text/generation/logits_processing.ex
+++ b/lib/bumblebee/text/generation/logits_processing.ex
@@ -3,6 +3,44 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
 
   import Nx.Defn
 
+  deftransform dfa_processor(logits, context, opts \\ []) do
+    opts = Keyword.validate!(opts, [:dfa])
+    dfa = opts[:dfa]
+
+    ## figure out current state from context
+    current_state = figure_out_state(context, dfa)
+    ## figure out allowed tokens for next sampling
+    allowed_tokens = dfa.transitions[current_state]
+    ## pass allowed tokens into allow_token_ids
+    allow_token_ids(logits, allowed_tokens)
+  end
+
+  ## figure out state from last token in context
+  deftransform figure_out_state(context, dfa) do
+
+    last_token = context.sequence[0]
+    if context.length == 512 do
+      :starting
+      # if context.length == 1 do
+      #   dbg("hon, honk")
+      #   dbg(context.sequence)
+      #   :starting
+      # else
+      #   :in_number
+      #   # last_token = context.sequence[context.length - 1]
+
+      #   # {state, _tokens} =
+      #   #   dfa.states
+      #   #   |> Enum.find(fn {_state, tokens} ->
+      #   #     last_token == nil || last_token in tokens
+      #   #   end)
+
+      #   # state
+    else
+      :in_number
+    end
+  end
+
   deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do
     opts = Keyword.validate!(opts, [:suppressed_token_ids])
 
diff --git a/lib/bumblebee/text/generation_config.ex b/lib/bumblebee/text/generation_config.ex
index 2070c2e5..20cfd14d 100644
--- a/lib/bumblebee/text/generation_config.ex
+++ b/lib/bumblebee/text/generation_config.ex
@@ -97,6 +97,10 @@ defmodule Bumblebee.Text.GenerationConfig do
       default: [],
       doc: "a list of token ids to enforce during generation (suppressing all tokens that are not in the list)"
     ],
+    dfa: [
+      default: nil,
+      doc: "the definition of a very simple deterministic finite automaton (dfa) for the generation"
+    ],
     no_repeat_ngram_length: [
       default: nil,
       doc: "when set, n-grams of the given length can occur only once in the generated sequence"
     ],
diff --git a/pair_programming.exs b/pair_programming.exs
index 2d8f3d8c..fd8869d4 100644
--- a/pair_programming.exs
+++ b/pair_programming.exs
@@ -27,22 +27,65 @@ Valid examples are:
[8,9,6,7]
"""

numbers = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]
start_token = ["["]
end_token = ["]"]
addition_token = [","]


# states
# * start_token -> start -> in_number, end -> numbers ++ end_token
# * numbers -> in_number -> in_number, addition, end -> numbers ++ addition_token ++ end_token
# * addition_token -> addition -> in_number -> numbers
# * end_token -> end -> END_OF_SEQUENCE -> END_OF_SEQUENCE

# {
#   start: start_token,
#   numbers: numbers,
# }

## in the beginning there was nothing
## -> start token
## last_state = inspect_last_token or last state from stack

## at the start we are in the start_token state and already have the start token
## choose the next candidates
## the big choice
## inspect current token -> determine which state was chosen
## next loop

## last token -> state

end_of_sequence_token = [Bumblebee.Tokenizer.special_token(tokenizer, :eos)]

## transitions
transitions = %{
  starting: Enum.map(start_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1)),
  in_array: Enum.map(numbers ++ end_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1)),
  in_number: Enum.map(numbers ++ addition_token ++ end_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1)),
  in_addition: Enum.map(numbers, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1)),
  ending: Enum.map(end_of_sequence_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
}

## states
states = %{
  starting: [],
  in_array: Enum.map(start_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1)),
  in_number: Enum.map(numbers, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1)),
  in_addition: Enum.map(addition_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1)),
  ending: Enum.map(end_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
}

dfa = %{
  states: states,
  transitions: transitions
}
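## A worked walk through the two maps above: generating "[4,7]" moves
## :starting -"["-> :in_array -"4"-> :in_number -","-> :in_addition
## -"7"-> :in_number -"]"-> :ending, and from :ending only the
## end-of-sequence token remains allowed.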
-all_allowed_token_ids = special_token_ids ++ allowed_token_ids generation_config = Bumblebee.configure(generation_config, max_new_tokens: 24, - allowed_token_ids: all_allowed_token_ids, - strategy: %{type: :multinomial_sampling, top_p: 0.6} + strategy: %{type: :multinomial_sampling, top_p: 0.6}, + dfa: dfa ) serving = From 63b3515eab60532b118f726cba386168ebcc0ef6 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Tue, 14 Oct 2025 11:52:53 +0200 Subject: [PATCH 05/18] working structured generation --- .../text/generation/logits_processing.ex | 90 +++++++++++++------ pair_programming.exs | 53 ++++++----- 2 files changed, 94 insertions(+), 49 deletions(-) diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex index 7b1889b0..6eed8342 100644 --- a/lib/bumblebee/text/generation/logits_processing.ex +++ b/lib/bumblebee/text/generation/logits_processing.ex @@ -7,39 +7,75 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do opts = Keyword.validate!(opts, [:dfa]) dfa = opts[:dfa] - ## figure out current state from context - current_state = figure_out_state(context, dfa) + transitions_max_length = Map.values(dfa.transitions) |> Enum.map(&length(&1)) |> Enum.max() + + transitions_tensor = + dfa.transitions + |> Enum.with_index() + |> Enum.map(fn {{_state, token_ids}, index} -> {index, Nx.tensor(token_ids)} end) + |> Enum.map(fn {_idx, tensor} -> + Nx.pad(tensor, -1, [{0, transitions_max_length - Nx.size(tensor), 0}]) + end) + |> Nx.stack() + + states_tensor = Nx.broadcast(-1, {Nx.size(logits)}) + + states_indices = Enum.map(dfa.states, fn {key, _value} -> key end) |> Nx.tensor() + states_indices = Nx.new_axis(states_indices, -1) + + states_states = Enum.map(dfa.states, fn {_key, state} -> state end) |> Nx.tensor() + + states_tensor = Nx.indexed_put(states_tensor, states_indices, states_states) + + state_logits(logits, context, transitions_tensor, states_tensor) + end + + defn state_logits(logits, context, transitions_tensor, states_tensor) do + ## states tensor + ## 0 -> -1 + ## ... + ## 75 -> 1 (in_array) + ## .. + current_state = + if context.length == context.input_length do + 0 + else + last_token = context.sequence[context.length - 1] + states_tensor[last_token] + end + ## figure out allowed tokens for next sampling - allowed_tokens = dfa.transitions[current_state] + + ## transition tensor + ## 0 -> token_ids for starting (padded to largest dimension) + ## 1 -> token_ids for in_array (padded to largest dimension) + ## ... 
    allowed_tokens = transitions_tensor[current_state]

    ## pass allowed tokens into allow_token_ids
    # allow_token_ids(logits, allowed_tokens)
    allow_token_ids(logits, allowed_tokens)
  end

  ## figure out state from last token in context
  # deftransform figure_out_state(context, dfa) do
  #   # _last_token = context.sequence[0]

  #   if context.length == 512 do
  #     :starting
  #   else
  #     :in_number
  #     last_token = context.sequence[context.length - 1]

  #     {state, _tokens} =
  #       dfa.states
  #       |> Enum.find(fn {_state, tokens} ->
  #         last_token == nil || last_token in tokens
  #       end)

  #     state
  #   end
  # end

  deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do
    opts = Keyword.validate!(opts, [:suppressed_token_ids])

diff --git a/pair_programming.exs b/pair_programming.exs
index fd8869d4..f4852149 100644
--- a/pair_programming.exs
+++ b/pair_programming.exs
@@ -16,15 +16,7 @@
 prompt = """
 Give me 10 random, single digit numbers in an array.
 Valid examples are:
 
-[1]
-[4,7]
-[2,4,1]
-[4,5,3,6]
-[1,7,8,0]
-[9,4,7,3,5,2]
-[8,2,3,8,6,4,8]
-[3,5,9]
-[8,9,6,7]
+[8,2,3,8,6,4,8,6,4,8]
 """

end_of_sequence_token = Bumblebee.Tokenizer.special_token(tokenizer, :eos)

states_to_num = %{
  starting: 0,
  in_array: 1,
  in_number: 2,
  in_addition: 3,
  ending: 4
}

## transitions
transitions = %{
  starting: Enum.map(start_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1)),
  in_array: Enum.map(numbers ++ end_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1)),
  in_number:
    Enum.map(
      numbers ++ addition_token ++ end_token,
      &Bumblebee.Tokenizer.token_to_id(tokenizer, &1)
    ),
  in_addition: Enum.map(numbers, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1)),
  ending: Enum.map([end_of_sequence_token], &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
}

## states
states =
  %{
    starting: [],
    in_array: Enum.map(start_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1)),
    in_number: Enum.map(numbers, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1)),
    in_addition: Enum.map(addition_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1)),
    ending: Enum.map(end_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
  }
  |> Enum.flat_map(fn {state, tensor_ids} ->
    for tensor_id <- tensor_ids do
      {tensor_id, states_to_num[state]}
    end
  end)

dfa = %{
  states: states,
  transitions: transitions
}

generation_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 48,
    strategy: %{type: :multinomial_sampling, top_p: 0.6},
    dfa: dfa
  )

From ae5ba7b6e492fdd78a163fdd78d3b16891dd848c Mon Sep 17 00:00:00 2001
From: Chris
Date: Tue, 14 Oct 2025 13:36:42 +0200
Subject: [PATCH 06/18] [#SAMPLE-2] Sampling an array of strings

https://bitcrowd.atlassian.net/browse/SAMPLE-2

From 85027c5741412d1bd99cc9454ef5a752f78e6870 Mon Sep 17 00:00:00 2001
From: Chris
Date: Tue, 14 Oct 2025 14:20:15 +0200
Subject: [PATCH 07/18] wip

---
 pair_programming.exs | 67 ++++++++++++++++++++------------------------
 1 file changed, 30 insertions(+), 37 deletions(-)

diff --git a/pair_programming.exs b/pair_programming.exs
index f4852149..5cd141ef 100644
--- a/pair_programming.exs
+++ b/pair_programming.exs
@@ -20,54 +20,47 @@ Valid examples are:
"""

numbers = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]
array_start_token = ["["]
array_end_token = ["]"]
array_addition_token = [","]
# A string token would require allowing everything except the closing " itself
string_token = "\"" # Token 18

# ToDo: should be a list -> idx

states_to_num = %{
  starting: 0,
  in_array: 1,
  in_number: 2,
  in_addition: 3,
  ending: 4
}

# ------------------------------------- above chars ------------------------------ #
# ------------------------------------- below tokens ------------------------------ #

end_of_sequence_token_ids = [Bumblebee.Tokenizer.special_token_id(tokenizer, :eos)]
special_tokens_ids = for token_id <- 0..17, do: token_id

array_start_token_ids = Enum.map(array_start_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
array_end_token_ids = Enum.map(array_end_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
addition_token_ids = Enum.map(array_addition_token, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
quote_token_ids = Enum.map([string_token], &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))

number_tokens_ids = Enum.map(numbers, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
vocabulary_token_ids = for token_id <- 0..model_info.vocabulary_size, do: token_id
forbidden_string_tokens_ids = quote_token_ids ++ special_tokens_ids
string_token_ids = vocabulary_token_ids -- forbidden_string_tokens_ids # everything except special tokens and ", which ends the string

## transitions
transitions = %{
  starting: array_start_token_ids,
  in_array: number_tokens_ids ++ array_end_token_ids ++ quote_token_ids, # todo start string token
  in_number: number_tokens_ids ++ addition_token_ids ++ array_end_token_ids,
  in_addition: number_tokens_ids,
  in_string: string_token_ids,
  ending: end_of_sequence_token_ids
}

From 8e1f7e3de1c5fcd77bed1d941e9316bc4ea5bc6a Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Tue, 14 Oct 2025 15:18:25 +0200
Subject: [PATCH 08/18] wip 2

---
 pair_programming.exs | 61 +++++++++++++++++++++++++++++---------------
 1 file changed, 41 insertions(+), 20 deletions(-)

diff --git a/pair_programming.exs b/pair_programming.exs
index 5cd141ef..9c056712 100644
--- a/pair_programming.exs
+++ b/pair_programming.exs
@@ -34,43 +34,64 @@
states_to_num = %{
  starting: 0,
  in_array: 1,
  in_number: 2,
  in_addition: 3,
  in_string: 4,
  end_of_string: 5,
  ending: 6
}

# ------------------------------------- above chars ------------------------------ #
# ------------------------------------- below tokens ------------------------------ #

array_start_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, array_start_token)
array_end_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, array_end_token)
addition_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, array_addition_token)
string_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, string_token)

end_of_sequence_token_id = Bumblebee.Tokenizer.special_token_id(tokenizer, :eos)
special_tokens_ids = for token_id <- 0..17, do: token_id
number_tokens_ids = Enum.map(numbers, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
vocabulary_token_ids = for token_id <- 0..model_info.spec.vocab_size, do: token_id

string_token_ids = vocabulary_token_ids -- [string_token_id] -- special_tokens_ids

## transitions
transitions = %{
  starting: [array_start_token_id],
  in_array: number_tokens_ids ++ [array_end_token_id, string_token_id], # todo start string token
  in_number: number_tokens_ids ++ [addition_token_id, array_end_token_id],
  in_addition: number_tokens_ids ++ [string_token_id],
  in_string: string_token_ids ++ [string_token_id],
  end_of_string: [addition_token_id, array_end_token_id],
  ending: [end_of_sequence_token_id]
}

## sequence : 75, 33, 34, ...

# State            0  1
# chosen Token id  75 18
# new state        1  3

## tensor
# State/token ids -> new state
## State / Token ids   0  1  2 ... 18 ... 33 ... 75 76
## starting (0)        -1 -1 -1    -1     -1     1
## in_array (1)                    4      2         6
## in_number (2)
## in_addition (3)
## in_string (4)
## end_of_string (5)
## ending (6)


## states
states =
  %{
    starting: [],
    in_array: [array_start_token_id],
    in_number: number_tokens_ids,
    in_addition: [addition_token_id],
    in_string: [string_token_id],
    end_of_string: [string_token_id],
    ending: [array_end_token_id]
  }
  |> Enum.flat_map(fn {state, tensor_ids} ->
    for tensor_id <- tensor_ids do
      {tensor_id, states_to_num[state]}
    end
  end)

From 67c5d1ab7a264eba9c0567124f4852149e8342aa Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Tue, 14 Oct 2025 16:33:04 +0200
Subject: [PATCH 09/18] with dfa tensors

---
 .../text/generation/logits_processing.ex      | 58 +++++++++--
 pair_programming.exs                          | 99 +++++++++++--------
 2 files changed, 106 insertions(+), 51 deletions(-)

diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex
index 6eed8342..30f58bbf 100644
--- a/lib/bumblebee/text/generation/logits_processing.ex
+++ b/lib/bumblebee/text/generation/logits_processing.ex
@@ -7,27 +7,65 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
    opts = Keyword.validate!(opts, [:dfa])
    dfa = opts[:dfa]

    allowed_token_ids_max_length =
      Map.values(dfa.allowed_token_ids_for_state)
      |> Enum.map(&length(&1))
      |> Enum.max()

    allowed_token_ids_tensor =
      dfa.allowed_token_ids_for_state
      |> Enum.with_index()
      |> Enum.map(fn {{_state, token_ids}, index} -> {index, Nx.tensor(token_ids)} end)
      |> Enum.map(fn {_idx, tensor} ->
        Nx.pad(tensor, -1, [{0, allowed_token_ids_max_length - Nx.size(tensor), 0}])
      end)
      |> Nx.stack()

    num_states = Map.keys(dfa.allowed_token_ids_for_state) |> length()

    states_transition_tensor = Nx.broadcast(-1, {num_states, Nx.size(logits)})

    states_transitions_tensor =
      for {current_state, token_id, next_state} <- dfa.state_transitions,
          reduce: states_transition_tensor do
        states_transition_tensor ->
          Nx.indexed_put(
            states_transition_tensor,
            Nx.tensor([current_state, token_id]),
            next_state
          )
      end

    initial_state = Nx.tensor([0]) |> Nx.vectorize(batch: 1)

    current_state =
      find_current_state(
        initial_state,
        states_transitions_tensor,
        context.sequence,
        context.input_length,
        context.length
      )

    allowed_tokens = allowed_token_ids_tensor[current_state]

    allow_token_ids(logits, allowed_tokens)
  end

  defn find_current_state(initial_state, states_transitions_tensor, sequence, input_length, current_length) do
    generated_length = current_length - input_length

    {state, _i, _sequence, _input_length, _generated_length, _states_transitions_tensor} =
      while {state = initial_state, i = 0, sequence, input_length, generated_length,
             states_transitions_tensor},
            Nx.less(i, generated_length) do
        chosen_token = sequence[input_length + i]
        new_state = states_transitions_tensor[[state, chosen_token]]
        {new_state, i + 1, sequence, input_length, generated_length, states_transitions_tensor}
      end

    state
  end

diff --git a/pair_programming.exs b/pair_programming.exs
index 9c056712..02bb2a36 100644
--- a/pair_programming.exs
+++ b/pair_programming.exs
@@ -13,51 +13,55 @@
sequence_length = 512

prompt = """
-Give me 10 random, single digit numbers in an array.
+Give me an array that contains a mix of numbers and text.
+There MUST be at least one number and one text.
 Valid examples are:
 
-[8,2,3,8,6,4,8,6,4,8]
+["hello",89,"hola",6,4,8]
 """

numbers = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]
array_start_token = "["
array_end_token = "]"
array_addition_token = ","
# String Token would require ! (like "everything, just without ....)
# Token 18
string_token = "\""
# ToDo: should be a list -> idx

states = [
  :starting,
  :in_array,
  :in_number,
  :in_addition,
  :in_string,
  :end_of_string,
  :ending
]

state_to_num = fn state -> Enum.find_index(states, & &1 == state) end

# ------------------------------------- above chars ------------------------------ #
# ------------------------------------- below tokens ------------------------------ #

array_start_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, array_start_token)
array_end_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, array_end_token)
addition_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, array_addition_token)
string_token_id = Bumblebee.Tokenizer.token_to_id(tokenizer, string_token)

end_of_sequence_token_id = Bumblebee.Tokenizer.special_token_id(tokenizer, :eos)
special_tokens_ids = for token_id <- 0..17, do: token_id
number_tokens_ids = Enum.map(numbers, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
vocabulary_token_ids = for token_id <- 0..model_info.spec.vocab_size, do: token_id

string_token_ids = vocabulary_token_ids -- ([string_token_id] -- special_tokens_ids)

## which tokens are allowed
allowed_token_ids_for_state = %{
  starting: [array_start_token_id],
  # todo start string token
  in_array: number_tokens_ids ++ [array_end_token_id, string_token_id],
  in_number: number_tokens_ids ++ [addition_token_id, array_end_token_id],
  in_addition: number_tokens_ids ++ [string_token_id],
  in_string: string_token_ids ++ [string_token_id],
  end_of_string: [addition_token_id, array_end_token_id],
  ending: [end_of_sequence_token_id]
}

## sequence : 75, 33, 34, ...

# State            0  1
# chosen Token id  75 18
# new state        1  3

## tensor
# State/token ids -> new state
## State / Token ids   0  1  2 ... 18 ... 33 ... 75 76
## starting (0)        -1 -1 -1    -1     -1     1
## in_array (1)                    4      2         6
## in_number (2)
## in_addition (3)
## in_string (4)
## end_of_string (5)
## ending (6)

## which tokens lead to which state from given state
state_transitions =
  [
    # starting
    {:starting, [array_start_token_id], :in_array},
    # in_array
    {:in_array, number_tokens_ids, :in_number},
    {:in_array, [array_end_token_id], :ending},
    {:in_array, [string_token_id], :in_string},
    # in_number
    {:in_number, number_tokens_ids, :in_number},
    {:in_number, [addition_token_id], :in_array},
    {:in_number, [array_end_token_id], :ending},
    # in_addition
    {:in_addition, number_tokens_ids, :in_addition},
    {:in_addition, [string_token_id], :in_string},
    # in_string
    {:in_string, string_token_ids, :in_string},
    {:in_string, [string_token_id], :end_of_string},
    # end_of_string
    {:end_of_string, [addition_token_id], :in_addition},
    {:end_of_string, [array_end_token_id], :ending}
    # ending
    # {:ending, [], :ending}
  ]
  |> Enum.flat_map(fn {current_state, tensor_ids, next_state} ->
    for tensor_id <- tensor_ids do
      {state_to_num.(current_state), tensor_id, state_to_num.(next_state)}
    end
  end)

dfa = %{
  state_transitions: state_transitions,
  allowed_token_ids_for_state: allowed_token_ids_for_state
}

generation_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 24,
    strategy: %{type: :multinomial_sampling, top_p: 0.6},
    dfa: dfa
  )

From 55f5c780b583b9e9ed7fce8c0b145d696c1861c0 Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Wed, 15 Oct 2025 09:55:10 +0200
Subject: [PATCH 10/18] with skipping for unambiguous tokens

---
 .../text/generation/logits_processing.ex      | 115 +++++++++++++++---
 pair_programming.exs                          |  16 ++-
 2 files changed, 112 insertions(+), 19 deletions(-)

diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex
index 30f58bbf..5c8b7eb6 100644
--- a/lib/bumblebee/text/generation/logits_processing.ex
+++ b/lib/bumblebee/text/generation/logits_processing.ex
@@ -36,38 +36,117 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
      end

    current_state =
      if dfa[:ambiguous_token_ids] do
        ambiguous_token_ids = Nx.tensor(dfa.ambiguous_token_ids)

        {token_ids, states} = Enum.unzip(dfa.simple_lookup)

        token_ids =
          Nx.tensor(token_ids)
          |> Nx.new_axis(-1)

        states = Nx.tensor(states)

        simple_lookup =
          Nx.broadcast(-1, {Nx.size(logits)})
          |> Nx.indexed_put(token_ids, states)

        initial_state = Nx.tensor([0]) |> Nx.vectorize(batch: 1)

        find_current_state_with_skip(
          initial_state,
          states_transitions_tensor,
          context.sequence,
          context.input_length,
          context.length,
          ambiguous_token_ids,
          simple_lookup
        )
      else
        initial_state = Nx.tensor([0]) |> Nx.vectorize(batch: 1)

        find_current_state(
          initial_state,
          states_transitions_tensor,
          context.sequence,
          context.input_length,
          context.length
        )
      end

    allowed_tokens = allowed_token_ids_tensor[current_state]

    allow_token_ids(logits, allowed_tokens)
  end

  defn find_current_state_with_skip(
         initial_state,
         states_transitions_tensor,
         sequence,
         input_length,
         current_length,
         ambiguous_token_ids,
         simple_lookup
       ) do
    generated_length = current_length - input_length
    last_token_id = sequence[current_length]

    state =
      cond do
        generated_length == 0 ->
          initial_state

        Nx.any(Nx.equal(last_token_id, ambiguous_token_ids)) ->
          {state, _i, _sequence, _input_length, _generated_length, _states_transitions_tensor} =
            while {state = initial_state, i = 0, sequence, input_length, generated_length,
                   states_transitions_tensor},
                  Nx.less(i, generated_length) do
              chosen_token = sequence[input_length + i]
              new_state = states_transitions_tensor[[state, chosen_token]]

              {new_state, i + 1, sequence, input_length, generated_length,
               states_transitions_tensor}
            end

          state

        true ->
          simple_lookup[last_token_id]
      end

    state
  end

  defn find_current_state(
         initial_state,
         states_transitions_tensor,
         sequence,
         input_length,
         current_length
       ) do
    generated_length = current_length - input_length

    cond do
      generated_length == 0 ->
        initial_state

      true ->
        {state, _i, _sequence, _input_length, _generated_length, _states_transitions_tensor} =
          while {state = initial_state, i = 0, sequence, input_length, generated_length,
                 states_transitions_tensor},
                Nx.less(i, generated_length) do
            chosen_token = sequence[input_length + i]
            new_state = states_transitions_tensor[[state, chosen_token]]

            {new_state, i + 1, sequence, input_length, generated_length,
             states_transitions_tensor}
          end

        state
    end
  end

diff --git a/pair_programming.exs b/pair_programming.exs
index 02bb2a36..f22b64c9 100644
--- a/pair_programming.exs
+++ b/pair_programming.exs
@@ -116,9 +116,23 @@
+ambiguous_token_ids =
+  state_transitions
+  |> Enum.map(fn {_current_state, tensor_id, next_state} -> {tensor_id, next_state} end)
+  |> Enum.dedup()
+  |> Enum.frequencies_by(fn {tensor_id, _state} -> tensor_id end)
+  |> Enum.filter(fn {_tensor_id, count} -> count > 1 end)
+  |> Enum.map(fn {tensor_id, _count} -> tensor_id end)
+
+simple_lookup =
+  for {state, token_id, _next_state} <- state_transitions, token_id not in ambiguous_token_ids do
+    {token_id, state}
+  end
+
 dfa = %{
   state_transitions: state_transitions,
-  allowed_token_ids_for_state: allowed_token_ids_for_state
+  allowed_token_ids_for_state: allowed_token_ids_for_state,
+  ambiguous_token_ids: ambiguous_token_ids,
+  simple_lookup: simple_lookup
 }

From 19aec28024a9a2a7d51beb4647810387d11ce2f5 Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Wed, 15 Oct 2025 09:55:29 +0200
Subject: [PATCH 11/18] benchmarks in script

---
 pair_programming.exs | 47 ++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 43 insertions(+), 4 deletions(-)

diff --git a/pair_programming.exs b/pair_programming.exs
index f22b64c9..9491df55 100644
--- a/pair_programming.exs
+++ b/pair_programming.exs
@@ -1,7 +1,8 @@
 Mix.install([
   {:bumblebee, path: "../bumblebee_bitcrowd"},
   {:nx, "~> 0.10.0", override: true},
-  {:emlx, github: "elixir-nx/emlx"}
+  {:emlx, github: "elixir-nx/emlx"},
+  {:benchee, "~> 1.0"}
 ])
 
 Nx.global_default_backend({EMLX.Backend, device: :gpu})

%{results: [_result]} = Nx.Serving.run(serving, prompt) |> dbg

# IO.puts result.text


serving_fn = fn max_new_tokens, dfa ->
  generation_config =
    Bumblebee.configure(generation_config,
      max_new_tokens: max_new_tokens,
      strategy: %{type: :multinomial_sampling, top_p: 0.6},
      dfa: dfa
    )

  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
    compile: [batch_size: 1, sequence_length: sequence_length],
    stream: false,
    defn_options: [compiler: Nx.Defn.Evaluator]
  )
end

serving_dfa_8 = serving_fn.(8, dfa)
serving_dfa_16 = serving_fn.(16, dfa)
serving_dfa_8_no_skip = serving_fn.(8, Map.delete(dfa, :ambiguous_token_ids))
serving_dfa_16_no_skip = serving_fn.(16, Map.delete(dfa, :ambiguous_token_ids))
serving_no_dfa_8 = serving_fn.(8, nil)
serving_no_dfa_16 = serving_fn.(16, nil)

Benchee.run(
  %{
    "max_new_tokens = 8" => fn -> Nx.Serving.run(serving_dfa_8, prompt) end,
    "max_new_tokens = 16" => fn -> Nx.Serving.run(serving_dfa_16, prompt) end,
    "no skip: max_new_tokens = 8" => fn -> Nx.Serving.run(serving_dfa_8_no_skip, prompt) end,
    "no skip: max_new_tokens = 16" => fn -> Nx.Serving.run(serving_dfa_16_no_skip, prompt) end,
    "no dfa: max_new_tokens = 8" => fn -> Nx.Serving.run(serving_no_dfa_8, prompt) end,
    "no dfa: max_new_tokens = 16" => fn -> Nx.Serving.run(serving_no_dfa_16, prompt) end
  },
  time: 30,
  memory_time: 2
)

From b7e9a1a110589536ffef1969962ffff682728746 Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Wed, 15 Oct 2025 17:42:47 +0200
Subject: [PATCH 12/18] single tensor for dfa definition

---
 .../text/generation/logits_processing.ex      | 187 ++++--------------
 pair_programming.exs                          |  46 +----
 2 files changed, 51 insertions(+), 182 deletions(-)

diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex
index 5c8b7eb6..3eb7ee10 100644
--- a/lib/bumblebee/text/generation/logits_processing.ex
+++ b/lib/bumblebee/text/generation/logits_processing.ex
@@ -7,193 +7,88 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
    opts = Keyword.validate!(opts, [:dfa])
    dfa = opts[:dfa]

    num_states =
      Enum.dedup_by(dfa.state_transitions, fn {state, _token_id, _next_state} -> state end)
      |> length()

    state_transition_tensor = Nx.broadcast(0, {num_states, Nx.size(logits)})

    state_transitions_tensor =
      for {current_state, token_id, next_state} <- dfa.state_transitions,
          reduce: state_transition_tensor do
        state_transition_tensor ->
          Nx.indexed_put(
            state_transition_tensor,
            Nx.tensor([current_state, token_id]),
            next_state
          )
      end

    initial_state = Nx.tensor([0]) |> Nx.vectorize(batch: 1)

    current_state =
      find_current_state(
        initial_state,
        state_transitions_tensor,
        context.sequence,
        context.input_length,
        context.length
      )

    suppressed_logits = Nx.fill(logits, Nx.Constants.neg_infinity(), type: Nx.type(logits))
    logits = Nx.select(state_transitions_tensor[current_state], logits, suppressed_logits)

    logits
  end

  defn find_current_state(
         initial_state,
         state_transitions_tensor,
         sequence,
         input_length,
         current_length
       ) do
    generated_length = current_length - input_length
    last_token_id = sequence[current_length]
    token_column = state_transitions_tensor[[.., last_token_id]] |> Nx.squeeze()

    # top_k gives the two top values + indices of the column
    # if the token is unambiguous, there is only one value != 0 in the column (that's top_values[0])
    # if top_values[1] != 0, there must be two values != 0 in the column, so it's ambiguous
    {top_values, top_indices} = Nx.top_k(token_column, k: 2)

    ambiguous_token? = Nx.logical_not(top_values[1])

    state =
      cond do
        generated_length == 0 ->
          initial_state

        ambiguous_token? ->
          {state, _i, _sequence, _input_length, _generated_length, _states_transitions_tensor} =
            while {state = initial_state, i = 0, sequence, input_length, generated_length,
                   state_transitions_tensor},
                  Nx.less(i, generated_length) do
              chosen_token = sequence[input_length + i]
              new_state = state_transitions_tensor[[state, chosen_token]]

              {new_state, i + 1, sequence, input_length, generated_length,
               state_transitions_tensor}
            end

          state

        true ->
          # we know that top_indices[0] is the row index for the only token id != 0
          # this is our new state!
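          # (a column with a second nonzero entry is the ambiguous case and is
          # handled by the while loop branch above instead)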
          top_indices[0]
      end

    print_value(state, label: "state")
    # state
  end

diff --git a/pair_programming.exs b/pair_programming.exs
index 9491df55..7a264eba 100644
--- a/pair_programming.exs
+++ b/pair_programming.exs
@@ -37,7 +37,8 @@
 states = [
   :starting,
   :in_array,
   :in_number,
   :in_addition,
   :in_string,
   :end_of_string,
-  :ending
+  :ending,
+  :done
 ]
 
 state_to_num = fn state -> Enum.find_index(states, & &1 == state) end

string_token_ids = vocabulary_token_ids -- ([string_token_id] -- special_tokens_ids)

## sequence : 75, 33, 34, ...

# State            0  1
# chosen Token id  75 18
# new state        1  3

## tensor
# State/token ids -> new state
## State / Token ids   0  1  2 ... 18 ... 33 ... 75 76
## starting (0)        -1 -1 -1    -1     -1     1
## in_array (1)                    4      2         6
## in_number (2)
## in_addition (3)
## in_string (4)
## end_of_string (5)
## ending (6)
## done (7)

## which tokens lead to which state from given state
state_transitions =
  [
    # starting
    {:starting, [array_start_token_id], :in_array},
    # in_array
    {:in_array, number_tokens_ids, :in_number},
    {:in_array, [array_end_token_id], :ending},
    {:in_array, [string_token_id], :in_string},
    # in_number
    {:in_number, number_tokens_ids, :in_number},
    {:in_number, [addition_token_id], :in_addition},
    {:in_number, [array_end_token_id], :ending},
    # in_addition
    {:in_addition, number_tokens_ids, :in_number},
    {:in_addition, [string_token_id], :in_string},
    # in_string
    {:in_string, string_token_ids, :in_string},
    {:in_string, [string_token_id], :end_of_string},
    # end_of_string
    {:end_of_string, [addition_token_id], :in_addition},
    {:end_of_string, [array_end_token_id], :ending},
    # ending
    {:ending, [end_of_sequence_token_id], :done}
  ]
  |> Enum.flat_map(fn {current_state, tensor_ids, next_state} ->
    for tensor_id <- tensor_ids do
      {state_to_num.(current_state), tensor_id, state_to_num.(next_state)}
    end
  end)

dfa = %{state_transitions: state_transitions}

generation_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 24,
    strategy: %{type: :multinomial_sampling, top_p: 0.6},
    dfa: dfa
  )

serving =
  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
    compile: [batch_size: 1, sequence_length: sequence_length],
    stream: false,
    defn_options: [compiler: Nx.Defn.Evaluator]
  )

%{results: [_result]} = Nx.Serving.run(serving, prompt) |> dbg

# IO.puts result.text

run_benchmarks = fn ->
  serving_fn = fn max_new_tokens, dfa ->
    generation_config =
      Bumblebee.configure(generation_config,
        max_new_tokens: max_new_tokens,
        strategy: %{type: :multinomial_sampling, top_p: 0.6},
        dfa: dfa
      )

    Bumblebee.Text.generation(model_info, tokenizer, generation_config,
      compile: [batch_size: 1, sequence_length: sequence_length],
      stream: false,
      defn_options: [compiler: Nx.Defn.Evaluator]
    )
  end

  serving_dfa_8 = serving_fn.(8, dfa)
  serving_dfa_16 = serving_fn.(16, dfa)
  serving_dfa_8_no_skip = serving_fn.(8, Map.delete(dfa, :ambiguous_token_ids))
  serving_dfa_16_no_skip = serving_fn.(16, Map.delete(dfa, :ambiguous_token_ids))
  serving_no_dfa_8 = serving_fn.(8, nil)
  serving_no_dfa_16 = serving_fn.(16, nil)

  Benchee.run(
    %{
      "max_new_tokens = 8" => fn -> Nx.Serving.run(serving_dfa_8, prompt) end,
      "max_new_tokens = 16" => fn -> Nx.Serving.run(serving_dfa_16, prompt) end,
      "no skip: max_new_tokens = 8" => fn -> Nx.Serving.run(serving_dfa_8_no_skip, prompt) end,
      "no skip: max_new_tokens = 16" => fn -> Nx.Serving.run(serving_dfa_16_no_skip, prompt) end,
      "no dfa: max_new_tokens = 8" => fn -> Nx.Serving.run(serving_no_dfa_8, prompt) end,
      "no dfa: max_new_tokens = 16" => fn -> Nx.Serving.run(serving_no_dfa_16, prompt) end
    },
    time: 30,
    memory_time: 2
  )
end

From 1872d1ccd95c833e89d63248c167384380e92ce3 Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Thu, 16 Oct 2025 09:44:45 +0200
Subject: [PATCH 13/18] fix condition for ambiguous token

---
 .../text/generation/logits_processing.ex      |  2 +-
 .../generation/logits_processing_test.exs     | 19 +++++++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex
index 3eb7ee10..87632b82 100644
--- a/lib/bumblebee/text/generation/logits_processing.ex
+++ b/lib/bumblebee/text/generation/logits_processing.ex
@@ -58,7 +58,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
     # if top_values[1] != 0, there must be two values != 0 in the column, so it's ambiguous
     {top_values, top_indices} = Nx.top_k(token_column, k: 2)
 
-    ambiguous_token? = Nx.logical_not(top_values[1])
+    ambiguous_token? = Nx.not_equal(top_values[1], Nx.tensor(0))
 
     state =
       cond do
diff --git a/test/bumblebee/text/generation/logits_processing_test.exs b/test/bumblebee/text/generation/logits_processing_test.exs
index 5bc5a44f..b996e9ef 100644
--- a/test/bumblebee/text/generation/logits_processing_test.exs
+++ b/test/bumblebee/text/generation/logits_processing_test.exs
@@ -5,6 +5,25 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
 
   alias Bumblebee.Text.Generation.LogitsProcessing
 
+  describe "find_current_state" do
+    test "finds ambiguous token ids" do
+      ## rows = State / columns = token ids -> value = new state
+      ## 0 1 2 3
+      ## 1 3 0 0
+      ## 2 2 0 0
+      ## token 1 is ambiguous
+      ambiguous_token_id = 1
+      state_transitions_tensor = Nx.tensor([[0,1,2,3], [1, 3, 0, 0], [2, 2, 0, 0 ]]) |> dbg
+      token_column = state_transitions_tensor[[.., ambiguous_token_id]] |> Nx.squeeze() |> dbg
+      {top_values, top_indices} = Nx.top_k(token_column, k: 2) |> dbg
+
+      ambiguous_token? = top_values[1] |> dbg
+
+      # assert(ambiguous_token?, Nx.tensor(1))
+      assert Nx.not_equal(ambiguous_token?, Nx.tensor(0))
+      # assert Nx.tensor(0)
+    end
+  end
   describe "suppressed_tokens_processor/3" do
     test "ignores the given tokens" do
       logits = Nx.tensor([1.0, 2.0, 3.0, 4.0])

From 7a0892c1dbd7a648af912ad3fc4f37364e47ac1f Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Thu, 16 Oct 2025 09:45:04 +0200
Subject: [PATCH 14/18] fix string_token_ids definition

---
 pair_programming.exs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pair_programming.exs b/pair_programming.exs
index 7a264eba..5146325b 100644
--- a/pair_programming.exs
+++ b/pair_programming.exs
@@ -56,7 +56,7 @@ special_tokens_ids = for token_id <- 0..17, do: token_id
 number_tokens_ids = Enum.map(numbers, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
 vocabulary_token_ids = for token_id <- 0..model_info.spec.vocab_size, do: token_id
 
-string_token_ids = vocabulary_token_ids -- ([string_token_id] -- special_tokens_ids)
+string_token_ids = vocabulary_token_ids -- [string_token_id] -- special_tokens_ids
 
 ## sequence : 75, 33, 34, ...

From 75faee3b7305ca5eebe8f0920de251ae1d93b6fc Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Thu, 16 Oct 2025 09:48:41 +0200
Subject: [PATCH 15/18] remove debug prints

---
 lib/bumblebee/text/generation/logits_processing.ex |  3 +--
 .../text/generation/logits_processing_test.exs     | 11 +++++------
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex
index 87632b82..268c68e3 100644
--- a/lib/bumblebee/text/generation/logits_processing.ex
+++ b/lib/bumblebee/text/generation/logits_processing.ex
@@ -85,8 +85,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
           top_indices[0]
       end
 
-    print_value(state, label: "state")
-    # state
+    state
   end
 
   deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do
diff --git a/test/bumblebee/text/generation/logits_processing_test.exs b/test/bumblebee/text/generation/logits_processing_test.exs
index b996e9ef..6c84a8f8 100644
--- a/test/bumblebee/text/generation/logits_processing_test.exs
+++ b/test/bumblebee/text/generation/logits_processing_test.exs
@@ -13,17 +13,16 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
     ## 2 2 0 0
     ## token 1 is ambiguous
       ambiguous_token_id = 1
-      state_transitions_tensor = Nx.tensor([[0,1,2,3], [1, 3, 0, 0], [2, 2, 0, 0 ]]) |> dbg
-      token_column = state_transitions_tensor[[.., ambiguous_token_id]] |> Nx.squeeze() |> dbg
-      {top_values, top_indices} = Nx.top_k(token_column, k: 2) |> dbg
+      state_transitions_tensor = Nx.tensor([[0, 1, 2, 3], [1, 3, 0, 0], [2, 2, 0, 0]])
+      token_column = state_transitions_tensor[[.., ambiguous_token_id]] |> Nx.squeeze()
+      {top_values, top_indices} = Nx.top_k(token_column, k: 2)
 
-      ambiguous_token? = top_values[1] |> dbg
+      ambiguous_token? = top_values[1]
 
-      # assert(ambiguous_token?, Nx.tensor(1))
       assert Nx.not_equal(ambiguous_token?, Nx.tensor(0))
-      # assert Nx.tensor(0)
     end
   end
+
   describe "suppressed_tokens_processor/3" do
     test "ignores the given tokens" do
       logits = Nx.tensor([1.0, 2.0, 3.0, 4.0])

From 8bafe3732674064f7522d826130193f6ae1d9b18 Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Fri, 17 Oct 2025 09:56:02 +0200
Subject: [PATCH 16/18] last_token_id = sequence[current_length - 1]

---
 lib/bumblebee/text/generation/logits_processing.ex | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex
index 268c68e3..7e3a8603 100644
--- a/lib/bumblebee/text/generation/logits_processing.ex
+++ b/lib/bumblebee/text/generation/logits_processing.ex
@@ -50,7 +50,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
       ) do
     generated_length = current_length - input_length
 
-    last_token_id = sequence[current_length]
+    last_token_id = sequence[current_length - 1]
     token_column = state_transitions_tensor[[.., last_token_id]] |> Nx.squeeze()

From d368e4b9c958a8a08e65dfbf32648442932629b2 Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Fri, 17 Oct 2025 10:23:16 +0200
Subject: [PATCH 17/18] actually fix string_token_ids definition

---
 pair_programming.exs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pair_programming.exs b/pair_programming.exs
index 5146325b..38b32bab 100644
--- a/pair_programming.exs
+++ b/pair_programming.exs
@@ -56,7 +56,7 @@ special_tokens_ids = for token_id <- 0..17, do: token_id
 number_tokens_ids = Enum.map(numbers, &Bumblebee.Tokenizer.token_to_id(tokenizer, &1))
 vocabulary_token_ids = for token_id <- 0..model_info.spec.vocab_size, do: token_id
 
-string_token_ids = vocabulary_token_ids -- [string_token_id] -- special_tokens_ids
+string_token_ids = vocabulary_token_ids -- ([string_token_id] ++ special_tokens_ids)
 
 ## sequence : 75, 33, 34, ...
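
The parenthesized union added above matters because -- is right-associative in
Elixir: a -- b -- c is parsed as a -- (b -- c). A minimal iex sketch (plain
Elixir, no project dependencies assumed):

    iex> [1, 2, 3] -- [1] -- [2]    # parsed as [1, 2, 3] -- ([1] -- [2])
    [2, 3]
    iex> ([1, 2, 3] -- [1]) -- [2]
    [3]

So the earlier spelling removed only the quote token id from the vocabulary,
while this patch also removes the special token ids.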

From 6a7675a5c92c79aeb99aa85642a34c212476a8f6 Mon Sep 17 00:00:00 2001
From: Joel Koch
Date: Fri, 17 Oct 2025 10:23:26 +0200
Subject: [PATCH 18/18] remove benchmarks from pair_programming script

---
 pair_programming.exs | 41 -----------------------------------------
 1 file changed, 41 deletions(-)

diff --git a/pair_programming.exs b/pair_programming.exs
index 38b32bab..ace3c12a 100644
--- a/pair_programming.exs
+++ b/pair_programming.exs
@@ -125,44 +125,3 @@ serving =
 
 %{results: [_result]} = Nx.Serving.run(serving, prompt) |> dbg
 
-# IO.puts result.text
-
-run_benchmarks = fn ->
-  serving_fn = fn max_new_tokens, dfa ->
-    generation_config =
-      Bumblebee.configure(generation_config,
-        max_new_tokens: max_new_tokens,
-        strategy: %{type: :multinomial_sampling, top_p: 0.6},
-        dfa: dfa
-      )
-
-    Bumblebee.Text.generation(model_info, tokenizer, generation_config,
-      compile: [batch_size: 1, sequence_length: sequence_length],
-      stream: false,
-      defn_options: [compiler: Nx.Defn.Evaluator]
-    )
-  end
-
-  serving_dfa_8 = serving_fn.(8, dfa)
-  serving_dfa_16 = serving_fn.(16, dfa)
-  serving_dfa_8_no_skip = serving_fn.(8, Map.delete(dfa, :ambiguous_token_ids))
-  serving_dfa_16_no_skip = serving_fn.(16, Map.delete(dfa, :ambiguous_token_ids))
-  serving_no_dfa_8 = serving_fn.(8, nil)
-  serving_no_dfa_16 = serving_fn.(16, nil)
-
-  Benchee.run(
-    %{
-      "max_new_tokens = 8" => fn -> Nx.Serving.run(serving_dfa_8, prompt) end,
-      "max_new_tokens = 16" => fn -> Nx.Serving.run(serving_dfa_16, prompt) end,
-      "no skip: max_new_tokens = 8" => fn -> Nx.Serving.run(serving_dfa_8_no_skip, prompt) end,
-      "no skip: max_new_tokens = 16" => fn -> Nx.Serving.run(serving_dfa_16_no_skip, prompt) end,
-      "no dfa: max_new_tokens = 8" => fn -> Nx.Serving.run(serving_no_dfa_8, prompt) end,
-      "no dfa: max_new_tokens = 16" => fn -> Nx.Serving.run(serving_no_dfa_16, prompt) end
-    },
-    time: 30,
-    memory_time: 2
-  )
-end
-