Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 61 additions & 113 deletions vllm_hpu_extension/bucketing/exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def __init__(self, max_num_seqs, max_num_prefill_seqs, block_size,
self._setup_buckets()
self.generate_prompt_buckets()

def _num_buckets(self, value, coef, minimum=4):
return max(math.ceil(math.pow(value, coef)), minimum)

def _setup_buckets(self) -> None:
default_max_prompt_seq = 1024
default_max_decode_seq = 2048
Expand All @@ -63,25 +66,33 @@ def _setup_buckets(self) -> None:
max_decode_seq = next((item for item in [self.max_decode_seq, self.max_model_len] if item is not None), default_max_decode_seq)
max_blocks = max(
self.block_size,
self.max_num_seqs * max_decode_seq // self.block_size)
math.ceil(self.max_num_seqs * max_decode_seq / self.block_size))

# TODO: The following coefficients are completely arbitrary and were
# selected by trial and error. This is something worth tuning
# in the future.
prompt_bs_coef = 0.4
prompt_seq_coef = 0.15 if get_config().prefix_caching else 0.25
decode_bs_coef = 0.4
decode_block_coef = 0.25

prompt_bs_limit = math.ceil(math.log2(self.max_num_prefill_seqs)) + 1
prompt_bs_limit = self._num_buckets(self.max_num_prefill_seqs, prompt_bs_coef)
self.global_state.prompt_bs_bucket_cfg = read_bucket_settings(
'prompt', 'bs', min=1, step=1, limit=prompt_bs_limit,
max=self.max_num_prefill_seqs)
decode_bs_limit = math.ceil(math.log2(self.max_num_seqs)) + 1
decode_bs_limit = self._num_buckets(self.max_num_seqs, decode_bs_coef)
self.global_state.decode_bs_bucket_cfg = read_bucket_settings(
'decode', 'bs', min=1, step=1, limit=decode_bs_limit,
'decode', 'bs', min=1, step=8, limit=decode_bs_limit,
max=self.max_num_seqs)
max_prompt_seq_limit = math.ceil(math.log2(max_prompt_seq)) + 1
max_prompt_seq_limit = self._num_buckets(max_prompt_seq, prompt_seq_coef)
self.global_state.prompt_seq_bucket_cfg = read_bucket_settings(
'prompt', 'seq', min=self.block_size, limit=max_prompt_seq_limit,
step=self.block_size, max=max_prompt_seq)
max_decode_block_limit = math.ceil(math.log2(max_blocks)) + 1
max_decode_block_limit = self._num_buckets(max_blocks, decode_block_coef)
self.global_state.decode_block_bucket_cfg = read_bucket_settings(
'decode', 'block', min=self.block_size, limit=max_decode_block_limit,
step=self.block_size, max=max_blocks)

msg = ("Prompt bucket config (min, step, max_warmup, limit) "
f"bs:{self.global_state.prompt_bs_bucket_cfg}, "
f"seq:{self.global_state.prompt_seq_bucket_cfg}")
Expand All @@ -95,12 +106,12 @@ def _setup_buckets(self) -> None:
def generate_prompt_buckets(self):
self.global_state.prompt_buckets, prompt_omitted_buckets = \
generate_prompt_buckets(
self.global_state.prompt_bs_bucket_cfg,
self.global_state.prompt_seq_bucket_cfg,
self.block_size,
self.prefix_caching,
self.max_num_batched_tokens,
self.max_model_len)
self.global_state.prompt_bs_bucket_cfg,
self.global_state.prompt_seq_bucket_cfg,
self.block_size,
self.prefix_caching,
self.max_num_batched_tokens,
self.max_model_len)

msg = (f"Generated {len(self.global_state.prompt_buckets)} "
f"prompt buckets [bs, seq]: "
Expand Down Expand Up @@ -139,7 +150,7 @@ def get_padded_prompt_seq_len(self, seq_len):
def get_padded_decode_num_blocks(self, num_blocks):
assert self.num_hpu_blocks is not None, "num_hpu_blocks is not set"
bucket_size = find_bucket(self.decode_buckets, num_blocks, 2)
return min(bucket_size, self.num_hpu_blocks)
return bucket_size

def get_padded_batch_size(self, batch_size, is_prompt):
if is_prompt:
Expand Down Expand Up @@ -236,20 +247,22 @@ def generate_prompt_buckets(bs_bucket_config,
max_model_len=None):
_, _, bmax, _ = seq_bucket_config
batch_size_buckets = warmup_range_with_limit(bs_bucket_config)
seq_bucket_config = warmup_range_with_limit(seq_bucket_config)
seq_buckets = warmup_range_with_limit(seq_bucket_config)

if prefix_caching:
buckets_3d = []
for bs in batch_size_buckets:
for b in seq_bucket_config:
max_blocks_range = (bmax - b) // block_size
for i in range(0, max_blocks_range + 1):
buckets_3d.append((bs, b, i))
for q_len in seq_bucket_config:
for ctx_len in [0] + seq_buckets:
if q_len + ctx_len > bmax:
break
num_ctx_blocks = math.ceil(ctx_len / block_size)
buckets_3d.append((bs, q_len, num_ctx_blocks))
buckets = buckets_3d
else:
buckets = list(
itertools.product(batch_size_buckets,
seq_bucket_config, [0]))
seq_buckets, [0]))

if len(buckets) == 0:
msg = ("No buckets could be captured with following config "
Expand Down Expand Up @@ -292,123 +305,58 @@ def generate_prompt_buckets(bs_bucket_config,
return captured_buckets, omitted_buckets


def flip_buckets(buckets):
    """Mirror a sorted bucket list so values cluster near bmax instead of bmin.

    Each bucket ``b`` is mapped to ``bmin + bmax - b`` and the order is
    reversed, so the returned list is still ascending but the dense end of
    the spacing moves from the low end to the high end of the range.

    Fix: returns ``[]`` for an empty input instead of raising IndexError
    (the original indexed ``buckets[0]`` / ``buckets[-1]`` unconditionally).
    """
    if not buckets:
        return []
    bmin = buckets[0]
    bmax = buckets[-1]
    return [bmax - b + bmin for b in reversed(buckets)]


def update_max_blocks(cfg, kv_blocks):
    """Cap the max entry (index 2) of a bucket config tuple at ``kv_blocks``.

    Returns a new tuple identical to ``cfg`` except that the third element
    is clamped to be no larger than the available KV-cache block count.
    """
    lo, step = cfg[0], cfg[1]
    capped_max = min(cfg[2], kv_blocks)
    return (lo, step, capped_max, *cfg[3:])


def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
max_blocks, max_model_len, block_size,
skip_invalid=False):
buckets = []
kv_blocks, max_model_len, block_size):
bs_buckets = warmup_range_with_limit(bs_bucket_config)
tmp_blocks_bucket_config = blocks_bucket_config
tmp_blocks_bucket_config = (*tmp_blocks_bucket_config[:2], max_blocks, tmp_blocks_bucket_config[-1])
block_buckets = warmup_range_with_limit(tmp_blocks_bucket_config)
last_bucket = max_blocks
valid_blocks = set()
if not skip_invalid:
#NOTE(kzawora): this case will generate all possible combinations of
# exponentially-spaced bs and blocks, even if combination is
# invalid (exceeds max_model_len). Unfortunately, this is necessary
# to handle scenario where bucket dimensions are determined by
# get_padded_decode_num_blocks or get_padded_decode_batch_size,
# since they don't include information about the other dimension.
# This will need to be refactored at some point in the model runner,
# but for now, we are dealing with this.
valid_blocks = set((bs, 1, x) for x in sorted(block_buckets) for bs in bs_buckets)
else:
#NOTE(kzawora): this case will generate only valid combinations of
# exponentially-spaced bs and blocks, where the product of bs and blocks
# is less than or equal to max_model_len. To handle corner cases
# (e.g. longer context due to fragmentation), we're adding an additional
# bucket with max_blocks for each batch size.
# For this to work properly, bucket dimensions need be requested as
# a combination of (batch_size, num_blocks), not separately.
for bs in bs_buckets:
max_blocks_per_bs = min(bs * math.ceil(max_model_len / block_size), last_bucket)
upper_bucket_bound = next(x for x in sorted(block_buckets) if x >= max_blocks_per_bs)
valid_blocks = set((bs, 1, x) for x in sorted(block_buckets) if x <= upper_bucket_bound)

buckets.extend(list(valid_blocks))
return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))


def warmup_range_with_limit(config: Tuple[int, int, int, int], fill=True):
blocks_bucket_config = update_max_blocks(blocks_bucket_config, kv_blocks)
block_buckets = warmup_range_with_limit(blocks_bucket_config)
block_buckets = flip_buckets(block_buckets)
return [(bs, 1, x) for x in sorted(block_buckets) for bs in bs_buckets]


def warmup_range_with_limit(config: Tuple[int, int, int, int]):
"""
NOTE(kzawora): we'll use exponential spacing for buckets in which scaled
power will return bmin for first bucket iteration, and bmax for last
iteration, with elements between determined by the exponent, and base being
unchanged. Note that after padding to bstep, duplicates may occur.
Handling of duplicates is configured by fill parameter.
If fill is False, duplicates are removed and less buckets are returned.
Duplicates are removed and less buckets are returned.

If fill is True, duplicates are resolved by selecting the closest (larger
or smaller) bucket. If duplicate resolution is not possible, less buckets
are returned. In that case, buckets are guaranteed to be linearly spaced.
Example (bmin=128, bstep=128, bmax=2048, num_buckets=10):
There are 16 possible buckets (2048/128), and we'll attempt to select 10 of
them with exponential spacing.
base = (bmax/bmin) ** (1/(num_buckets-1)); (2048/128) ** (1/9) = 1.36079
exponent = i
power = base ** exponent
scaled_power = b_min * power
For i == 0 (first bucket), power is 1.36079 ** 0 = 1;
scaled_power is 1 * 128 = 128 (==bmin)
For i == 9 (last bucket), power is 1.36079 ** 9 = 16;
scaled_power is 16 * 128 = 2048 (==bmax)
So, computing for all buckets:
scaled_powers_unpadded = [bmin*base^0(==bmin), bmin*base^1, bmin*base^2, ..., bmin*base^9(==bmax)]
scaled_powers_unpadded = [128.00, 174.18, 237.02, 322.54, 438.91, 597.26, 812.75, 1105.98, 1505.01, 2048.00]

if fill is False:
scaled_powers_padded = [ 128, 256, 256, 384, 512, 640, 896, 1152, 1536, 2048]
^_______^
duplicates
buckets = [ 128, 256, 384, 512, 640, 896, 1152, 1536, 2048]
^
duplicate bucket removed
len(buckets) = 9, num_buckets = 10
if fill is True:
buckets = [ 128, 256, 384, 512, 640, 768, 896, 1152, 1536, 2048]
^_______^_______^_______^
closest unused buckets selected
^_______^_______^
these become duplicates once previous duplicates are resolved

In this case we'll have four duplicated buckets:
174.18 -> 256, optimal bucket,
237.02 -> (256) -> 384, taking closest available bucket,
as optimal bucket 256 was already captured by 174.18,
322.54 -> (384) -> 512, taking closest available bucket,
as optimal bucket 384 was already captured by 237.02,
438.91 -> (512) -> 640, taking closest available bucket,
as optimal bucket 512 was already captured by 322.54,
597.26 -> (640) -> 768, taking closest available bucket,
as optimal bucket 640 was already captured by 438.91,
812.75 -> 896, optimal bucket
len(buckets) = 10, num_buckets = 10
In this case, the end result has the same buckets as fill=False,
but with additional bucket 768 added.
The difference is more pronounced for larger ranges and larger number
of buckets.
""" # noqa: E501

bmin, bstep, bmax, num_buckets = config
linear_buckets = set(np.arange(bmin, bmax + 1, step=bstep))
assert num_buckets > 0, "num_buckets must be a positive integer"
if num_buckets == 1:
return [bmax]
buckets: Set[Tuple[int, int]] = set()
for i in range(num_buckets):
power_unpadded = bmin * np.float_power(
bmax / bmin, (1. / float(num_buckets - 1)) * i)
if i == num_buckets - 1 and get_config().use_contiguous_pa:
bucket = bmax
else:
power_unpadded = bmin * np.float_power(
bmax / bmin, (1. / float(num_buckets - 1)) * i)
bucket = math.ceil(power_unpadded / bstep) * bstep
if fill and bucket in buckets:
available_buckets = linear_buckets.difference(buckets)
if len(available_buckets) == 0:
break # there are no more unique buckets, let's exit now
new_bucket = min(available_buckets,
key=lambda x: abs(x - power_unpadded))
buckets.add(new_bucket)
else:
buckets.add(bucket)
return list(sorted(buckets))
buckets.add(bucket)
buckets = list(sorted(buckets))
return buckets