diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 8f18f0144193..95eeabab9006 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -4,7 +4,6 @@ import random from copy import deepcopy from dataclasses import dataclass -from unittest.mock import patch import pytest import torch @@ -136,7 +135,6 @@ def populate_loras( id_to_index: list[int | None], layer: BaseLayerWithLoRA, layer_weights: torch.Tensor, - generate_embeddings_tensor: int = 0, repeats: int = 1, ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]: """This method populates the lora layers with lora weights. @@ -148,8 +146,6 @@ def populate_loras( layer: the LoRAlayer to populate. layer_weights: the PyTorch tensor containing the layer's weights. - generate_embeddings_tensor: whether to generate an - embeddings tensor for each LoRA. repeats: must only be set for column parallel packed layers. Indicates the number of loras to compose together to create a single lora layer. @@ -171,7 +167,6 @@ def populate_loras( sublora = DummyLoRAManager(layer_weights.device).init_random_lora( module_name=f"fake_{i}", weight=layer_weights, - generate_embeddings_tensor=generate_embeddings_tensor, ) sublora.lora_b = sublora.lora_b[ (sublora_len * i) : (sublora_len * (i + 1)), : @@ -185,7 +180,6 @@ def populate_loras( slot_idx, lora_a=lora.lora_a, lora_b=lora.lora_b, - embeddings_tensor=lora.embeddings_tensor, ) lora_dict[lora_id] = lora @@ -306,7 +300,6 @@ def create_random_embedding_layer(): id_to_index, max_loras, vocab_size, - lora_config.lora_extra_vocab_size, ) lora_result = lora_embedding(torch.cat(inputs)) @@ -344,7 +337,6 @@ def create_random_embedding_layer(): id_to_index, max_loras, vocab_size, - lora_config.lora_extra_vocab_size, ) lora_result = lora_embedding(torch.cat(inputs)) @@ -354,149 +346,6 @@ def create_random_embedding_layer(): torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) -@torch.inference_mode() -# @pytest.mark.skip( -# reason="Fails when loras are in any slot other than the first.") -@pytest.mark.parametrize("num_loras", [1, 2, 4]) -@pytest.mark.parametrize("device", DEVICES) -@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) -@pytest.mark.parametrize("stage", STAGES) -def test_embeddings_with_new_embeddings( - dist_init, num_loras, device, vocab_size, stage -) -> None: - if current_platform.is_cuda_alike(): - torch.cuda.set_device(device) - - torch.set_default_device(device) - max_loras = 8 - punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) - assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig( - max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16 - ) - - def create_random_embedding_layer(): - embedding = VocabParallelEmbedding(vocab_size, 256) - embedding_data = torch.rand_like(embedding.weight.data) - embedding.weight.data = embedding_data - embedding.weight.data[vocab_size:, :] = 0 - expanded_embedding = VocabParallelEmbedding( - vocab_size + lora_config.lora_extra_vocab_size * max_loras, - 256, - org_num_embeddings=vocab_size, - ) - expanded_embedding.weight.data[:vocab_size, :] = embedding_data - # We need to deepcopy the embedding as it will be modified - # in place - lora_embedding = VocabParallelEmbeddingWithLoRA(deepcopy(expanded_embedding)) - lora_embedding.create_lora_weights(max_loras, lora_config) - - return expanded_embedding, lora_embedding - - for i in range(NUM_RANDOM_SEEDS): - set_random_seed(i) - - id_to_index = 
get_random_id_to_index(num_loras, max_loras) - expanded_embedding, lora_embedding = create_random_embedding_layer() - lora_dict, _ = populate_loras( - id_to_index, - layer=lora_embedding, - layer_weights=torch.zeros( - (256, vocab_size + lora_config.lora_extra_vocab_size) - ), - generate_embeddings_tensor=256, - ) - - lora_embedding.set_mapping(punica_wrapper) - # All embeddings tensors have the same shape. - embeddings_tensors = [ - lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) - ] - embeddings_tensor_len = embeddings_tensors[0].shape[0] - - # Add empty embeddings_tensors for unoccupied lora slots. - for _ in range(max_loras - len(embeddings_tensors)): - embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape)) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=list(lora_dict.keys()), - num_inputs=num_loras * 3, - input_size=(200,), - input_range=(1, vocab_size), - device=device, - ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - vocab_size, - lora_config.lora_extra_vocab_size, - ) - original_inputs = deepcopy(inputs) - - # Force some of the inputs to be in the extended embeddings range - # to guarantee that their behavior is tested. - for input_, original_input_, lora_id in zip( - inputs, original_inputs, prompt_mapping - ): - embedding_id = lora_id - 1 - input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len) - original_input_[-1] = vocab_size - input_[-2] = vocab_size + ((embedding_id + 1) * embeddings_tensor_len - 1) - original_input_[-2] = vocab_size + embeddings_tensor_len - 1 - - expanded_embedding.weight[ - vocab_size : vocab_size + (embeddings_tensor_len * max_loras) - ] = torch.cat(embeddings_tensors) - - lora_result = lora_embedding(torch.cat(original_inputs)) - - expected_results: list[torch.Tensor] = [] - for input_, original_input_, lora_id in zip( - inputs, original_inputs, prompt_mapping - ): - lora = lora_dict[lora_id] - result = expanded_embedding(input_) - after_a = F.embedding( - original_input_, - lora.lora_a.T, - ) - result += after_a @ lora.lora_b.T - expected_results.append(result) - expected_result = torch.cat(expected_results) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) - - # Check that resetting the lora weights succeeds - - for slot_idx in range(max_loras): - lora_embedding.reset_lora(slot_idx) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=num_loras * 3, - input_size=(200,), - input_range=(1, vocab_size), - device=device, - ) - original_inputs = deepcopy(inputs) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - vocab_size, - lora_config.lora_extra_vocab_size, - ) - lora_result = lora_embedding(torch.cat(original_inputs)) - expected_result = expanded_embedding(torch.cat(inputs)) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) - - @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @@ -518,16 +367,13 @@ def test_lm_head_logits_processor( def _pretest(): linear = ParallelLMHead( - vocab_size + lora_config.lora_extra_vocab_size, - 1024, - vocab_size, + num_embeddings=vocab_size, + 
embedding_dim=1024, params_dtype=torch.float16, ) linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data[:, vocab_size:] = 0 - logits_processor = LogitsProcessor( - vocab_size + lora_config.lora_extra_vocab_size, vocab_size - ) + logits_processor = LogitsProcessor(vocab_size) lora_logits_processor = LogitsProcessorWithLoRA( logits_processor, 1024, linear.weight.dtype, linear.weight.device, None ) @@ -541,15 +387,12 @@ def _pretest(): id_to_index = get_random_id_to_index(num_loras, max_loras) linear, logits_processor, lora_logits_processor = _pretest() lora_logits_processor.set_mapping(punica_wrapper) - # NOTE: all the generated loras share the same embeddings tensor. + lora_dict, _ = populate_loras( id_to_index, layer=lora_logits_processor, layer_weights=linear.weight, - generate_embeddings_tensor=1024, ) - embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor - embeddings_tensor_len = embeddings_tensor.shape[0] inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=list(lora_dict.keys()), @@ -565,7 +408,6 @@ def _pretest(): id_to_index, max_loras, vocab_size, - lora_config.lora_extra_vocab_size, ) input_ = torch.rand(20, 1024) @@ -575,23 +417,16 @@ def _pretest(): original_lm_head = deepcopy(linear) - linear.weight[ - logits_processor.org_vocab_size : logits_processor.org_vocab_size - + embeddings_tensor_len - ] = embeddings_tensor - - logits_processor.org_vocab_size = vocab_size + lora_config.lora_extra_vocab_size expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] result = logits_processor._get_logits( hidden_states=input_, lm_head=linear, embedding_bias=None ) - result[:, vocab_size + embeddings_tensor_len :] = float("-inf") + result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) - logits_processor.org_vocab_size = vocab_size # Check that resetting the lora weights succeeds @@ -612,7 +447,6 @@ def _pretest(): id_to_index, max_loras, vocab_size, - lora_config.lora_extra_vocab_size, ) lora_result = lora_logits_processor._get_logits( @@ -694,7 +528,6 @@ def create_random_linear_replicated_layer(): id_to_index, max_loras, 512, - lora_config.lora_extra_vocab_size, ) lora_result = lora_linear(torch.cat(inputs))[0] @@ -726,7 +559,10 @@ def create_random_linear_replicated_layer(): lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( - lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size + lora_mapping, + id_to_index, + max_loras, + 512, ) lora_result = lora_linear(torch.cat(inputs))[0] @@ -817,7 +653,6 @@ def create_random_linear_parallel_layer(): id_to_index, max_loras, 512, - lora_config.lora_extra_vocab_size, ) lora_result = lora_linear(torch.cat(inputs))[0] @@ -849,7 +684,10 @@ def create_random_linear_parallel_layer(): lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( - lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size + lora_mapping, + id_to_index, + max_loras, + 512, ) lora_result = lora_linear(torch.cat(inputs))[0] @@ -963,7 +801,6 @@ class FakeConfig: id_to_index, max_loras, 512, - lora_config.lora_extra_vocab_size, ) lora_result = lora_linear(torch.cat(inputs))[0] @@ -1000,7 +837,6 @@ class FakeConfig: id_to_index, max_loras, 512, - lora_config.lora_extra_vocab_size, ) lora_result = 
lora_linear(torch.cat(inputs))[0] @@ -1010,109 +846,6 @@ class FakeConfig: torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) -@pytest.mark.parametrize("tp_size", [1, 2, 4, 8]) -@pytest.mark.parametrize( - "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)) -) -def test_vocab_parallel_embedding_indices(tp_size, seed): - random.seed(seed) - vocab_size = random.randint(4000, 64000) - added_vocab_size = random.randint(0, 1024) - org_vocab_size = vocab_size - added_vocab_size - last_org_vocab_end_index = 0 - last_added_vocab_end_index = org_vocab_size - computed_vocab_size = 0 - computed_org_vocab_size = 0 - computed_added_vocab_size = 0 - vocab_size_padded = -1 - - all_org_tokens: list[int] = [] - all_added_tokens: list[int] = [] - token_ids: list[int] = [] - - for tp_rank in range(tp_size): - with ( - patch( - "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", - return_value=tp_rank, - ), - patch( - "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", - return_value=tp_size, - ), - ): - vocab_embedding = VocabParallelEmbedding( - vocab_size, 1, org_num_embeddings=org_vocab_size - ) - vocab_size_padded = vocab_embedding.num_embeddings_padded - shard_indices = vocab_embedding.shard_indices - # Assert that the ranges are contiguous - assert shard_indices.org_vocab_start_index == last_org_vocab_end_index - assert shard_indices.added_vocab_start_index == last_added_vocab_end_index - - # Ensure that we are not exceeding the vocab size - computed_vocab_size += shard_indices.num_elements_padded - computed_org_vocab_size += shard_indices.num_org_elements - computed_added_vocab_size += shard_indices.num_added_elements - - # Ensure that the ranges are not overlapping - all_org_tokens.extend( - range( - shard_indices.org_vocab_start_index, shard_indices.org_vocab_end_index - ) - ) - all_added_tokens.extend( - range( - shard_indices.added_vocab_start_index, - shard_indices.added_vocab_end_index, - ) - ) - - token_ids.extend( - range( - shard_indices.org_vocab_start_index, shard_indices.org_vocab_end_index - ) - ) - token_ids.extend( - [-1] - * (shard_indices.num_org_elements_padded - shard_indices.num_org_elements) - ) - token_ids.extend( - range( - shard_indices.added_vocab_start_index, - shard_indices.added_vocab_end_index, - ) - ) - token_ids.extend( - [-1] - * ( - shard_indices.num_added_elements_padded - - shard_indices.num_added_elements - ) - ) - - last_org_vocab_end_index = shard_indices.org_vocab_end_index - last_added_vocab_end_index = shard_indices.added_vocab_end_index - - assert computed_vocab_size == vocab_size_padded - assert computed_org_vocab_size == org_vocab_size - assert computed_added_vocab_size == added_vocab_size - - # Ensure that the ranges are not overlapping - assert len(all_org_tokens) == len(set(all_org_tokens)) - assert len(all_added_tokens) == len(set(all_added_tokens)) - assert not set(all_org_tokens).intersection(set(all_added_tokens)) - - token_ids_tensor = torch.tensor(token_ids, dtype=torch.long) - reindex_mapping = vocab_embedding.get_sharded_to_full_mapping() - assert reindex_mapping is not None or tp_size == 1 - if reindex_mapping is not None: - reindexed_token_ids = token_ids_tensor[reindex_mapping] - expected = torch.tensor(list(range(0, vocab_size))) - assert reindexed_token_ids[:vocab_size].equal(expected) - assert torch.all(reindexed_token_ids[vocab_size:] == -1) - - def test_get_masked_input_and_mask(): x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 
8, 9, 10, 11]) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index e7816031142e..24d4dfca46d6 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -48,9 +48,6 @@ @pytest.mark.parametrize("device", DEVICES) def test_from_lora_tensors(sql_lora_files, device): tensors = load_file(os.path.join(sql_lora_files, "adapter_model.safetensors")) - new_embeddings = load_file( - os.path.join(sql_lora_files, "new_embeddings.safetensors") - ) peft_helper = PEFTHelper.from_local_dir( sql_lora_files, max_position_embeddings=4096 @@ -60,7 +57,6 @@ def test_from_lora_tensors(sql_lora_files, device): tensors, peft_helper=peft_helper, device=device, - embeddings=new_embeddings, embedding_modules=EMBEDDING_MODULES, embedding_padding_modules=EMBEDDING_PADDING_MODULES, ) @@ -76,18 +72,6 @@ def test_from_lora_tensors(sql_lora_files, device): f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" ) assert lora.lora_a.shape[0] == 8 - embeddings_module = next( - (k for k in EMBEDDING_MODULES if k in module_name), None - ) - if embeddings_module: - assert torch.equal( - lora.embeddings_tensor, - new_embeddings[EMBEDDING_MODULES[embeddings_module]].to( - device=lora.embeddings_tensor.device - ), - ) - else: - assert lora.embeddings_tensor is None def create_lora( @@ -552,9 +536,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path worker_adapter_manager = WorkerLoRAManager( vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES ) - worker_adapter_manager.vocab_size = ( - dummy_model_gate_up.unpadded_vocab_size - lora_config.lora_extra_vocab_size - ) + worker_adapter_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size worker_adapter_manager.create_lora_manager(dummy_model_gate_up) dummy_lora_files = f"{tmp_path}/lora_adapter" diff --git a/tests/lora/utils.py b/tests/lora/utils.py index d30b77f09466..6aba5299b582 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -28,7 +28,6 @@ def init_random_lora( module_name: str, weight: torch.Tensor, rank: int = 8, - generate_embeddings_tensor: int = 0, ): lora = LoRALayerWeights( module_name, @@ -41,13 +40,6 @@ def init_random_lora( [weight.shape[0], rank], dtype=weight.dtype, device=self._device ), ) - if generate_embeddings_tensor: - lora.embeddings_tensor = torch.rand( - 5, - generate_embeddings_tensor, - dtype=weight.dtype, - device=self._device, - ) self.set_module_lora(module_name, lora) return lora diff --git a/vllm/config/lora.py b/vllm/config/lora.py index 84e92eef4007..072e0ec2104f 100644 --- a/vllm/config/lora.py +++ b/vllm/config/lora.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib -from typing import TYPE_CHECKING, Any, ClassVar, Literal +from typing import TYPE_CHECKING, Any, Literal import torch from pydantic import ConfigDict, Field, model_validator @@ -11,7 +11,6 @@ from vllm.config.utils import config from vllm.logger import init_logger -from vllm.platforms import current_platform if TYPE_CHECKING: from vllm.config import ModelConfig @@ -46,19 +45,6 @@ class LoRAConfig: `max_loras`.""" lora_dtype: torch.dtype | LoRADType = "auto" """Data type for LoRA. If auto, will default to base model dtype.""" - lora_extra_vocab_size: LoRAExtraVocabSize = Field( - default=256, - deprecated=( - "`lora_extra_vocab_size` is deprecated and will be removed " - "in v0.12.0. Additional vocabulary support for " - "LoRA adapters is being phased out." 
- ), - ) - """(Deprecated) Maximum size of extra vocabulary that can be present in a - LoRA adapter. Will be removed in v0.12.0.""" - lora_vocab_padding_size: ClassVar[int] = ( - current_platform.get_lora_vocab_padding_size() - ) default_mm_loras: dict[str, str] | None = None """Dictionary mapping specific modalities to LoRA model paths; this field is only applicable to multimodal models and should be leveraged when a @@ -87,8 +73,6 @@ def compute_hash(self) -> str: factors.append(self.max_loras) factors.append(self.fully_sharded_loras) factors.append(self.lora_dtype) - factors.append(self.lora_extra_vocab_size) - factors.append(self.lora_vocab_padding_size) hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b025004ea022..d34d6498a404 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -477,7 +477,6 @@ class EngineArgs: fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras max_cpu_loras: int | None = LoRAConfig.max_cpu_loras lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype - lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override @@ -984,9 +983,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"]) - lora_group.add_argument( - "--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"] - ) lora_group.add_argument( "--lora-dtype", **lora_kwargs["lora_dtype"], @@ -1582,7 +1578,6 @@ def create_engine_config( max_loras=self.max_loras, default_mm_loras=self.default_mm_loras, fully_sharded_loras=self.fully_sharded_loras, - lora_extra_vocab_size=self.lora_extra_vocab_size, lora_dtype=self.lora_dtype, max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 diff --git a/vllm/lora/layers/base.py b/vllm/lora/layers/base.py index 0c7e80684889..62326c05b2bd 100644 --- a/vllm/lora/layers/base.py +++ b/vllm/lora/layers/base.py @@ -44,7 +44,6 @@ def set_lora( index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None, ): """Overwrites lora tensors at index.""" ... 
diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index d619a0edc124..890ccd65ec81 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -96,7 +96,6 @@ def set_lora( index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None, ): # Except for QKVParallelLinearWithLoRA and # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index 637ded9b2a0f..273c4950e323 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -248,7 +248,6 @@ def set_lora( index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None, ): self.reset_lora(index) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 8fb3efa220f6..bb3ccf9ec7a8 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -392,8 +392,6 @@ def set_lora( index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None, - bias: torch.Tensor | None = None, ): """Overwrites lora tensors at index.""" self.reset_lora(index) diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py index adc5e861f57f..06f92652031e 100644 --- a/vllm/lora/layers/logits_processor.py +++ b/vllm/lora/layers/logits_processor.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math import torch import torch.nn as nn @@ -108,22 +107,13 @@ def create_lora_weights( ( max_loras, 1, - # Pad for kernel compatibility - math.ceil( - self.base_layer.vocab_size / lora_config.lora_vocab_padding_size - ) - * lora_config.lora_vocab_padding_size, + self.base_layer.vocab_size, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, device=self.device, ) - self.embeddings_tensors = torch.full( - (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), - fill_value=float("-inf"), - dtype=self.dtype, - device=self.device, - ) + if self.sharded_to_full_mapping is not None: self.sharded_to_full_mapping_gpu = torch.tensor( self.sharded_to_full_mapping, device=self.device, dtype=torch.long @@ -134,14 +124,12 @@ def create_lora_weights( def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 - self.embeddings_tensors[index] = float("-inf") def set_lora( self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None, ): self.reset_lora(index) self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_( @@ -150,12 +138,6 @@ def set_lora( self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_( lora_b, non_blocking=True ) - if embeddings_tensor is not None: - self.embeddings_tensors[ - index, - : embeddings_tensor.shape[0], - : embeddings_tensor.shape[1], - ] = embeddings_tensor def _get_logits( self, @@ -193,39 +175,6 @@ def _get_logits( # token_id: [0, 1, 2, 3, 4, 5, -1, -1] logits = logits[:, self.sharded_to_full_mapping_gpu] - lora_logits = torch.empty( - self.embeddings_tensors.shape[0] + 1, - self.embeddings_tensors.shape[1], - hidden_states.shape[0], - dtype=self.embeddings_tensors.dtype, - device=self.embeddings_tensors.device, - ) - torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1]) - - neg_inf, pos_inf = current_platform.get_infinity_values(lora_logits.dtype) 
- - lora_logits[-1] = neg_inf - lora_logits = lora_logits.mT - indices_padded = self.punica_wrapper.sampler_indices_padded - - if current_platform.is_tpu() or current_platform.is_xpu(): - indices_padded = indices_padded[: logits.size(0)] - - lora_logits = ( - lora_logits.reshape( - lora_logits.shape[0] * lora_logits.shape[1], - lora_logits.shape[2], - ) - .index_select(0, indices_padded) - .nan_to_num_(nan=neg_inf, posinf=pos_inf, neginf=neg_inf) - ) - - logits[ - :, - self.base_layer.org_vocab_size : self.base_layer.org_vocab_size - + lora_logits.shape[1], - ] = lora_logits - lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits( logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0 ) diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py index ca4ad8012e9c..a3e18fbc1c02 100644 --- a/vllm/lora/layers/vocal_parallel_embedding.py +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -46,19 +46,10 @@ def create_lora_weights( self.embeddings_slice = None self.embeddings_weights = None - self.embeddings_tensors = torch.zeros( - ( - max_loras, - lora_config.lora_extra_vocab_size, - self.base_layer.embedding_dim, - ), - dtype=self.base_layer.weight.dtype, - device=self.base_layer.weight.device, - ) self.lora_a_stacked = torch.zeros( ( max_loras, - self.base_layer.org_vocab_size + lora_config.lora_extra_vocab_size, + self.base_layer.org_vocab_size, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, @@ -82,14 +73,12 @@ def create_lora_weights( def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 - self.embeddings_tensors[index] = 0 def set_lora( self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None, ): self.reset_lora(index) # NOTE self.lora_a_stacked is row-major, and lora_a is col-major, @@ -100,36 +89,16 @@ def set_lora( self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_( lora_b, non_blocking=True ) - if embeddings_tensor is not None: - self.embeddings_tensors[ - index, - : embeddings_tensor.shape[0], - : embeddings_tensor.shape[1], - ].copy_(embeddings_tensor, non_blocking=True) - if self.embeddings_slice is not None: - # TODO(yard1): Optimize this copy, we don't need to copy - # everything, just the modified part - embeddings = self.embeddings_tensors.view( - self.embeddings_tensors.shape[0] * self.embeddings_tensors.shape[1], - self.embeddings_tensors.shape[2], - )[self.embeddings_slice[0] : self.embeddings_slice[1]] - assert self.embeddings_weights is not None - self.embeddings_weights[: embeddings.shape[0]].copy_(embeddings) def forward(self, x: torch.Tensor) -> torch.Tensor: - added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0) - # NB: Don't use torch.narrow here. 
torch.narrow triggers some # Dynamic Shape specialization in torch.compile - num_tokens = x.shape[0] - indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens] - indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens] full_lora_a_embeddings = F.embedding( - x + indices_1, + x, self.lora_a_stacked_2d, ) - full_output = self.base_layer.forward(x + (indices_0 * added_tokens_mask)) + full_output = self.base_layer.forward(x) full_output_org = full_output if full_output.ndim == 3: diff --git a/vllm/lora/lora_weights.py b/vllm/lora/lora_weights.py index 7691481d5039..f0d8e2219405 100644 --- a/vllm/lora/lora_weights.py +++ b/vllm/lora/lora_weights.py @@ -21,7 +21,6 @@ def __init__( lora_alpha: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None = None, scaling: float | None = None, ) -> None: self.module_name = module_name @@ -29,7 +28,6 @@ def __init__( self.lora_alpha = lora_alpha self.lora_a = lora_a self.lora_b = lora_b - self.embeddings_tensor = embeddings_tensor if scaling is None: self.scaling = self.lora_alpha / self.rank @@ -56,18 +54,11 @@ def output_dim(self) -> int: def is_packed(self) -> bool: return False - @property - def extra_vocab_size(self) -> int: - return ( - self.embeddings_tensor.shape[0] if self.embeddings_tensor is not None else 0 - ) - @classmethod def from_config( cls, module_name: str, peft_helper: PEFTHelper, - embeddings_tensor: torch.Tensor | None = None, ) -> "LoRALayerWeights": # lora_a and lora_b are set to None for config-based construction return cls( @@ -76,7 +67,6 @@ def from_config( peft_helper.lora_alpha, None, None, - embeddings_tensor, peft_helper.vllm_lora_scaling_factor, ) @@ -89,7 +79,6 @@ def create_dummy_lora_weights( rank: int, dtype: torch.dtype, device: torch.types.Device, - embeddings_tensor_dim: int | None = None, ) -> "LoRALayerWeights": pin_memory = str(device) == "cpu" and is_pin_memory_available() lora_a = torch.zeros( @@ -99,24 +88,12 @@ def create_dummy_lora_weights( [output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory ) - embeddings_tensor = ( - torch.rand( - 10, - embeddings_tensor_dim, - dtype=dtype, - device=device, - pin_memory=pin_memory, - ) - if embeddings_tensor_dim - else None - ) return cls( module_name, rank=rank, lora_alpha=1, lora_a=lora_a, lora_b=lora_b, - embeddings_tensor=embeddings_tensor, ) @@ -139,7 +116,6 @@ def __init__( lora_a=lora_a, lora_b=lora_b, scaling=scaling, # type: ignore - embeddings_tensor=None, ) self.lora_alphas = lora_alphas if scaling is None: diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 02c252f15bfa..0ef6919775fa 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -93,14 +93,6 @@ def clone(self, lora_model_id: int) -> "LoRAModel": loras=self.loras.copy(), ) - @property - def extra_vocab_size(self) -> int: - return ( - max(lora.extra_vocab_size for lora in self.loras.values()) - if self.loras - else 0 - ) - def get_lora(self, module_name: str) -> LoRALayerWeights | None: """Get LoRA for a given module by name""" return self.loras.get(module_name, None) @@ -117,7 +109,6 @@ def from_lora_tensors( peft_helper: PEFTHelper, device: str = "cuda", dtype: torch.dtype | None = None, - embeddings: dict[str, torch.Tensor] | None = None, target_embedding_padding: int | None = None, embedding_modules: dict[str, str] | None = None, embedding_padding_modules: list[str] | None = None, @@ -131,20 +122,8 @@ def from_lora_tensors( tensor_name, weights_mapper ) if module_name not in loras: - lora_embeddings_tensor = None 
- if embeddings: - assert embedding_modules is not None - embeddings_module = next( - (k for k in embedding_modules if k in module_name), None - ) - if embeddings_module: - lora_embeddings_tensor = embeddings[ - embedding_modules[embeddings_module] - ].to(device=device, dtype=dtype) - if pin_memory: - lora_embeddings_tensor = lora_embeddings_tensor.pin_memory() loras[module_name] = LoRALayerWeights.from_config( - module_name, peft_helper, lora_embeddings_tensor + module_name, peft_helper ) if is_lora_a: @@ -206,10 +185,10 @@ def from_local_checkpoint( lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt") - new_embeddings_tensor_path = os.path.join( - lora_dir, "new_embeddings.safetensors" - ) - new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin") + # new_embeddings_tensor_path = os.path.join( + # lora_dir, "new_embeddings.safetensors" + # ) + # new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin") tensors: dict[str, torch.Tensor] = {} unexpected_modules: list[list[str] | str] = [] @@ -300,21 +279,12 @@ def check_unexpected_modules(modules: dict): else: raise ValueError(f"{lora_dir} doesn't contain tensors") - embeddings = None - if os.path.isfile(new_embeddings_tensor_path): - embeddings = safetensors.torch.load_file(new_embeddings_tensor_path) - elif os.path.isfile(new_embeddings_bin_file_path): - embeddings = torch.load( - new_embeddings_bin_file_path, map_location=device, weights_only=True - ) - return cls.from_lora_tensors( lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id, tensors=tensors, peft_helper=peft_helper, device=device, dtype=dtype, - embeddings=embeddings, target_embedding_padding=target_embedding_padding, embedding_modules=embedding_modules, embedding_padding_modules=embedding_padding_modules, @@ -474,7 +444,6 @@ def activate_adapter( index, module_lora.lora_a, module_lora.lora_b, - module_lora.embeddings_tensor, ) else: module.reset_lora(index) @@ -505,7 +474,6 @@ def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: self.lora_index_to_id, self.lora_slots + 1, self.vocab_size, - self.lora_config.lora_extra_vocab_size, ) def remove_all_adapters(self): @@ -616,7 +584,6 @@ def create_dummy_lora( if parts[-1] in embedding_modules: input_dim = ( module.base_layer.org_vocab_size - + self.lora_config.lora_extra_vocab_size if hasattr(module.base_layer, "org_vocab_size") else module.base_layer.weight.shape[1] ) @@ -625,11 +592,6 @@ def create_dummy_lora( if hasattr(module.base_layer, "embedding_dim") else module.base_layer.weight.shape[0] ) - embeddings_tensor_dim = ( - module.base_layer.embedding_dim - if hasattr(module.base_layer, "embedding_dim") - else module.base_layer.weight.shape[1] - ) lora = LoRALayerWeights.create_dummy_lora_weights( module_name, input_dim, @@ -637,7 +599,6 @@ def create_dummy_lora( rank, module.lora_a_stacked[0].dtype, "cpu", - embeddings_tensor_dim=embeddings_tensor_dim, ) else: lora = LoRALayerWeights.create_dummy_lora_weights( diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index b6186e856152..6564ce2e3fa6 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -31,7 +31,6 @@ def update_metadata( lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, - extra_vocab_size: int, **kwargs, ) -> None: """ @@ -145,14 +144,9 @@ def 
__init__( self._sampler_indices_padded = torch.empty( max_num_batched_tokens, dtype=torch.long, device=device ) - self._embeddings_indices = torch.empty( - 2, max_num_batched_tokens, dtype=torch.long, device=device - ) - - # 4 is the number of indices tensors. - # base_indices, sampler_indices, sampler_indices_padded, - # embeddings_indices - self.indices_len: list[int | None] = [None] * 4 + # 3 is the number of indices tensors. + # base_indices, sampler_indices, sampler_indices_padded + self.indices_len: list[int | None] = [None] * 3 # these attributes are the information required for sgmv kernel self._seq_start_locs = torch.empty(max_batches, dtype=torch.long, device=device) self._seq_lengths = torch.empty(max_batches, dtype=torch.long, device=device) @@ -172,20 +166,17 @@ def _update_base_metadata( lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, - extra_vocab_size: int, ): ( base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, indices_len, ) = convert_mapping( mapping, lora_index_to_id, max_loras, vocab_size, - extra_vocab_size, self.device, ) self._token_lora_indices[: base_indices.shape[0]].copy_(base_indices) @@ -193,10 +184,6 @@ def _update_base_metadata( self._sampler_indices_padded[: sampler_indices_padded.shape[0]].copy_( sampler_indices_padded ) - self._embeddings_indices[ - : embeddings_indices.shape[0], : embeddings_indices.shape[1] - ].copy_(embeddings_indices) - self.indices_len[:] = indices_len def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None: @@ -270,27 +257,15 @@ def sampler_indices_padded(self) -> torch.Tensor: indices_padded_len = self.indices_len[2] return self._sampler_indices_padded[:indices_padded_len] - @property - def embeddings_indices(self) -> torch.Tensor: - """ - This property provides access to the indices used for lora embeddings, - specifically for VocabParallelEmbeddingWithLoRA. - """ - embeddings_indices_len = self.indices_len[3] - return self._embeddings_indices[:, :embeddings_indices_len] - def update_metadata( self, mapping: "LoRAMapping", lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, - extra_vocab_size: int, **kwargs, ): - self._update_base_metadata( - mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size - ) + self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size) if mapping.is_prefill: # Update metadata required for prefill-related operators. 
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index ede50a48af98..19301079e00b 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -65,13 +65,10 @@ def update_metadata( lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, - extra_vocab_size: int, **kwargs, ): self.is_prefill = mapping.is_prefill - self._update_base_metadata( - mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size - ) + self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size) # Prepare cuda kernel metadata tensors self.token_mapping_meta.prepare_tensors(self.token_lora_indices) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 090878dcd254..b41ca7d4790c 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -46,24 +46,14 @@ def __init__( torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True) torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True) torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, True) - torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True) torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, True) torch._dynamo.mark_dynamic(self._token_lora_indices, 0) - torch._dynamo.mark_dynamic(self._embeddings_indices, 1) torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor: return torch.narrow(self._token_lora_indices, 0, 0, x.size(0)) - @property - def embeddings_indices(self) -> torch.Tensor: - """ - This property provides access to the indices used for lora embeddings, - specifically for VocabParallelEmbeddingWithLoRA. 
-        """
-        return self._embeddings_indices[:]
-
     @property
     def sampler_indices_padded(self) -> torch.Tensor:
         """
@@ -292,7 +282,6 @@ def _update_base_metadata(
         lora_index_to_id: list[int | None],
         max_loras: int,
         vocab_size: int,
-        extra_vocab_size: int,
     ):
         # Make sure we don't accidentally collect outside operations
         torch_xla.sync()
@@ -306,14 +295,12 @@ def _update_base_metadata(
             base_indices,
             sampler_indices,
             sampler_indices_padded,
-            embeddings_indices,
             indices_len,
         ) = convert_mapping(
             mapping,
             lora_index_to_id,
             max_loras,
             vocab_size,
-            extra_vocab_size,
             "cpu",
         )
         self._token_lora_indices = self._pad_to_shape(
@@ -325,9 +312,6 @@ def _update_base_metadata(
             sampler_indices_padded, self._sampler_indices_padded.shape, dims=1
         ).to(self.device)
-        self._embeddings_indices = self._pad_to_shape(
-            embeddings_indices, self._embeddings_indices.shape, dims=2
-        ).to(self.device)
         self.indices_len[:] = indices_len
 
     def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None:
diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py
index b95087d0ff83..bdf71c6a11a1 100644
--- a/vllm/lora/punica_wrapper/punica_xpu.py
+++ b/vllm/lora/punica_wrapper/punica_xpu.py
@@ -34,7 +34,6 @@ def __init__(
     ):
         PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
         torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
-        torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
         torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
 
     def update_metadata(
@@ -43,13 +42,10 @@ def update_metadata(
         lora_index_to_id: list[int | None],
         max_loras: int,
         vocab_size: int,
-        extra_vocab_size: int,
         **kwargs,
     ):
         self.is_prefill = mapping.is_prefill
-        self._update_base_metadata(
-            mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
-        )
+        self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
 
     def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
         return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py
index 584745f86b1a..71c55fd0a0c5 100644
--- a/vllm/lora/punica_wrapper/utils.py
+++ b/vllm/lora/punica_wrapper/utils.py
@@ -54,9 +54,8 @@ def convert_mapping(
     lora_index_to_id: list[int | None],
     max_loras: int,
     vocab_size: int,
-    extra_vocab_size: int,
     device: torch.device,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[int]]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[int]]:
     """Converts LoRAMapping to index tensors.
 
     Args:
@@ -64,7 +63,7 @@ def convert_mapping(
         lora_index_to_id: List mapping LoRA ids to LoRA indices.
        max_loras: Maximum number of LoRAs.
         vocab_size: Model vocab size.
-        extra_vocab_size: Extra vocab size each LoRA can have.
+        device: Device to put the tensors on.
 
     Returns:
         A tuple of tensors:
@@ -78,13 +77,8 @@ def convert_mapping(
             requests to LoRA indices for sampler with padding. Same as
             sampler_indices, but -1 is replaced with
             max_loras.
-        embeddings_indices: Tensor of shape [2, batch_size] mapping
-            requests to embedding indices. First row is for embeddings
-            added by the LoRAs, second row is for the LoRA.lora_a
-            embeddings.
         indices_len: List of lengths of the above tensors. It contains
-            (base_indices, sampler_indices, sampler_indices_padded,
-            embeddings_indices).
+            (base_indices, sampler_indices, sampler_indices_padded).
""" index_mapping_indices: list[int] = list(mapping.index_mapping).copy() embedding_indices = index_mapping_indices.copy() @@ -114,15 +108,7 @@ def convert_mapping( prompt_mapping_tensor = torch.tensor( prompt_mapping, dtype=torch.long, device=device ) - embeddings_indices = torch.stack( - [ - indices[2] * extra_vocab_size, - indices[2] * (vocab_size + extra_vocab_size), - ] - ) - embeddings_indices = torch.where( - embeddings_indices == -1, max_loras - 1, embeddings_indices - ) + base_indices = indices[1] sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() @@ -138,13 +124,11 @@ def convert_mapping( base_indices.shape[-1], sampler_indices.shape[-1], sampler_indices_padded.shape[-1], - embeddings_indices.shape[-1], ] return ( base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, indices_len, ) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index b85151f2c759..4cc201a6414f 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -121,8 +121,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: lora_model_id=lora_request.lora_int_id, device="cpu", dtype=self.lora_config.lora_dtype, - target_embedding_padding=self.vocab_size - + self.lora_config.lora_extra_vocab_size, + target_embedding_padding=self.vocab_size, embedding_modules=self.embedding_modules, embedding_padding_modules=self.embedding_padding_modules, tensorizer_config_dict=lora_request.tensorizer_config_dict, @@ -143,12 +142,6 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: # For BadRequestError raise e - if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: - raise ValueError( - f"LoRA added vocab size {lora.extra_vocab_size} " - f"is greater than lora_extra_vocab_size " - f"{self.lora_config.lora_extra_vocab_size}." 
- ) return lora def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index c44b4021471e..172bbda12bbd 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -47,7 +47,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -276,29 +275,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, ) else: @@ -435,28 +421,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config + self.quant_config = quant_config self.model = GraniteModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -468,7 +444,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logit_scale /= config.logits_scaling self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, scale=logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c49a1ea817f9..7444c2486323 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -48,7 +48,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -386,24 +385,18 @@ def __init__( config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config 
- lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) else: @@ -578,9 +571,7 @@ def __init__( super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.model = self._init_model( vllm_config=vllm_config, @@ -589,20 +580,9 @@ def __init__( ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -611,7 +591,7 @@ def __init__( logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index d7a1cb82fb4f..30c701ae714e 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -51,7 +51,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -306,23 +305,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + parallel_config = vllm_config.parallel_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab + + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.enable_eplb = parallel_config.enable_eplb @@ -513,34 +507,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + self.config = config - self.lora_config = lora_config + self.quant_config = quant_config self.model = MixtralModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, 
config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/teleflm.py b/vllm/model_executor/models/teleflm.py index 8a0bec9dff84..bebd7bcaa924 100644 --- a/vllm/model_executor/models/teleflm.py +++ b/vllm/model_executor/models/teleflm.py @@ -74,5 +74,5 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.output_mult = self.config.output_mult / self.mup_scale_factor logit_scale = self.output_mult self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.config.vocab_size, logit_scale + self.config.vocab_size, scale=logit_scale )