diff --git a/README.md b/README.md
index aeff0e00..ee1a7fb5 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,8 @@ This package integrates Large Language Models (LLMs) into [spaCy](https://spacy.
   - **[OpenLLaMA](https://huggingface.co/openlm-research)**
   - **[StableLM](https://huggingface.co/stabilityai)**
   - **[Mistral](https://huggingface.co/mistralai)**
+  - **[Zephyr](https://huggingface.co/HuggingFaceH4)**
+  - **[Yi](https://huggingface.co/01-ai)**
 - Integration with [LangChain](https://github.com/hwchase17/langchain) 🦜️🔗 - all `langchain` models and features can be used in `spacy-llm`
 - Tasks available out of the box:
   - Named Entity Recognition
diff --git a/spacy_llm/models/hf/__init__.py b/spacy_llm/models/hf/__init__.py
index b3afbb71..c683781e 100644
--- a/spacy_llm/models/hf/__init__.py
+++ b/spacy_llm/models/hf/__init__.py
@@ -5,6 +5,8 @@
 from .mistral import mistral_hf
 from .openllama import openllama_hf
 from .stablelm import stablelm_hf
+from .yi import yi_hf
+from .zephyr import zephyr_hf
 
 __all__ = [
     "HuggingFace",
@@ -14,4 +16,6 @@
     "mistral_hf",
     "openllama_hf",
     "stablelm_hf",
+    "yi_hf",
+    "zephyr_hf",
 ]
diff --git a/spacy_llm/models/hf/mistral.py b/spacy_llm/models/hf/mistral.py
index 3c5039a2..a45915b4 100644
--- a/spacy_llm/models/hf/mistral.py
+++ b/spacy_llm/models/hf/mistral.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional
 
 from confection import SimpleFrozenDict
 
@@ -94,13 +94,12 @@ def mistral_hf(
     name: Mistral.MODEL_NAMES,
     config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
     config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
-) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
+) -> Mistral:
     """Generates Mistral instance that can execute a set of prompts and return the raw responses.
-    name (Literal): Name of the Falcon model. Has to be one of Falcon.get_model_names().
+    name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names().
     config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
     config_run (Optional[Dict[str, Any]]): HF config for running the model.
-    RETURNS (Callable[[Iterable[str]], Iterable[str]]): Falcon instance that can execute a set of prompts and return
-        the raw responses.
+    RETURNS (Mistral): Mistral instance that can execute a set of prompts and return the raw responses.
""" return Mistral( name=name, config_init=config_init, config_run=config_run, context_length=8000 diff --git a/spacy_llm/models/hf/yi.py b/spacy_llm/models/hf/yi.py new file mode 100644 index 00000000..6cb7807e --- /dev/null +++ b/spacy_llm/models/hf/yi.py @@ -0,0 +1,128 @@ +from typing import Any, Dict, Iterable, List, Optional, Tuple + +from confection import SimpleFrozenDict + +from ...compat import Literal, transformers +from ...registry.util import registry +from .base import HuggingFace + + +class Yi(HuggingFace): + MODEL_NAMES = Literal[ # noqa: F722 + "Yi-34B", + "Yi-34B-chat-8bits", + "Yi-6B-chat", + "Yi-6B", + "Yi-6B-200K", + "Yi-34B-chat", + "Yi-34B-chat-4bits", + "Yi-34B-200K", + ] + + def __init__( + self, + name: MODEL_NAMES, + config_init: Optional[Dict[str, Any]], + config_run: Optional[Dict[str, Any]], + context_length: int, + ): + self._tokenizer: Optional["transformers.AutoTokenizer"] = None + self._is_instruct = "instruct" in name + super().__init__( + name=name, + config_init=config_init, + config_run=config_run, + context_length=context_length, + ) + + assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase) + + # Instantiate GenerationConfig object from config dict. + self._hf_config_run = transformers.GenerationConfig.from_pretrained( + self._name, **self._config_run + ) + # To avoid deprecation warning regarding usage of `max_length`. + self._hf_config_run.max_new_tokens = self._hf_config_run.max_length + + def init_model(self) -> Any: + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + self._name, use_fast=False + ) + init_cfg = self._config_init + device: Optional[str] = None + if "device" in init_cfg: + device = init_cfg.pop("device") + + model = transformers.AutoModelForCausalLM.from_pretrained( + self._name, **init_cfg, resume_download=True + ).eval() + if device: + model.to(device) + + return model + + @property + def hf_account(self) -> str: + return "01-ai" + + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] + assert hasattr(self._model, "generate") + assert hasattr(self._tokenizer, "apply_chat_template") + assert self._tokenizer + + responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + prompts_for_doc = list(prompts_for_doc) + + tokenized_input_ids = [ + self._tokenizer.apply_chat_template( + conversation=[{"role": "user", "content": prompt}], + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + ) + for prompt in prompts_for_doc + ] + tokenized_input_ids = [ + tp.to(self._model.device) for tp in tokenized_input_ids + ] + + responses.append( + [ + self._tokenizer.decode( + self._model.generate( + input_ids=tok_ii, generation_config=self._hf_config_run + )[:, tok_ii.shape[1] :][0], + skip_special_tokens=True, + ).strip("\n") + for tok_ii in tokenized_input_ids + ] + ) + + return responses + + @staticmethod + def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: + default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs() + return {**default_cfg_init, **{"torch_dtype": "auto"}}, default_cfg_run + + +@registry.llm_models("spacy.Yi.v1") +def yi_hf( + name: Yi.MODEL_NAMES, + config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), + config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), +) -> Yi: + """Generates Yi instance that can execute a set of prompts and return the raw responses. + name (Literal): Name of the Yi model. Has to be one of Yi.get_model_names(). 
+    config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
+    config_run (Optional[Dict[str, Any]]): HF config for running the model.
+    RETURNS (Yi): Yi instance that can execute a set of prompts and return the raw responses.
+    """
+    return Yi(
+        name=name,
+        config_init=config_init,
+        config_run=config_run,
+        context_length=200000 if "200K" in name else 32000,
+    )
diff --git a/spacy_llm/models/hf/zephyr.py b/spacy_llm/models/hf/zephyr.py
new file mode 100644
index 00000000..26d4aab5
--- /dev/null
+++ b/spacy_llm/models/hf/zephyr.py
@@ -0,0 +1,101 @@
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+from confection import SimpleFrozenDict
+
+from ...compat import Literal, transformers
+from ...registry.util import registry
+from .base import HuggingFace
+
+
+class Zephyr(HuggingFace):
+    MODEL_NAMES = Literal["zephyr-7b-beta"]  # noqa: F722
+
+    def __init__(
+        self,
+        name: MODEL_NAMES,
+        config_init: Optional[Dict[str, Any]],
+        config_run: Optional[Dict[str, Any]],
+        context_length: int,
+    ):
+        super().__init__(
+            name=name,
+            config_init=config_init,
+            config_run=config_run,
+            context_length=context_length,
+        )
+
+        # Instantiate GenerationConfig object from config dict.
+        self._hf_config_run = transformers.GenerationConfig.from_pretrained(
+            self._name, **self._config_run
+        )
+        # To avoid deprecation warning regarding usage of `max_length`.
+        self._hf_config_run.max_new_tokens = self._hf_config_run.max_length
+
+    def init_model(self) -> Any:
+        return transformers.pipeline(
+            "text-generation",
+            model=self._name,
+            return_full_text=False,
+            **self._config_init
+        )
+
+    @property
+    def hf_account(self) -> str:
+        return "HuggingFaceH4"
+
+    def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]:  # type: ignore[override]
+        responses: List[List[str]] = []
+
+        for prompts_for_doc in prompts:
+            formatted_prompts_for_doc = [
+                self._model.tokenizer.apply_chat_template(
+                    [{"role": "user", "content": prompt}],
+                    tokenize=False,
+                    add_generation_prompt=False,
+                )
+                for prompt in prompts_for_doc
+            ]
+
+            responses.append(
+                [
+                    self._model(prompt, generation_config=self._hf_config_run)[0][
+                        "generated_text"
+                    ]
+                    .replace("<|assistant|>", "")
+                    .strip("\n")
+                    for prompt in formatted_prompts_for_doc
+                ]
+            )
+
+        return responses
+
+    @staticmethod
+    def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs()
+        return default_cfg_init, {
+            **default_cfg_run,
+            **{
+                "max_new_tokens": 256,
+                "do_sample": True,
+                "temperature": 0.7,
+                "top_k": 50,
+                "top_p": 0.95,
+            },
+        }
+
+
+@registry.llm_models("spacy.Zephyr.v1")
+def zephyr_hf(
+    name: Zephyr.MODEL_NAMES,
+    config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
+    config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
+) -> Zephyr:
+    """Generates Zephyr instance that can execute a set of prompts and return the raw responses.
+    name (Literal): Name of the Zephyr model. Has to be one of Zephyr.get_model_names().
+    config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
+    config_run (Optional[Dict[str, Any]]): HF config for running the model.
+    RETURNS (Zephyr): Zephyr instance that can execute a set of prompts and return the raw responses.
+ """ + return Zephyr( + name=name, config_init=config_init, config_run=config_run, context_length=8000 + ) diff --git a/spacy_llm/tasks/entity_linker/util.py b/spacy_llm/tasks/entity_linker/util.py index 55c44d6f..0f3399e8 100644 --- a/spacy_llm/tasks/entity_linker/util.py +++ b/spacy_llm/tasks/entity_linker/util.py @@ -206,4 +206,10 @@ def reduce_shards_to_doc(task: EntityLinkerTask, shards: Iterable[Doc]) -> Doc: RETURNS (Doc): Fused doc instance. """ # Entities are additive, so we can just merge shards. - return Doc.from_docs(list(shards), ensure_whitespace=True) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=".*Skipping .* while merging docs.", + ) + return Doc.from_docs(list(shards), ensure_whitespace=True) diff --git a/spacy_llm/tasks/sentiment/util.py b/spacy_llm/tasks/sentiment/util.py index 4352b62c..8f01e037 100644 --- a/spacy_llm/tasks/sentiment/util.py +++ b/spacy_llm/tasks/sentiment/util.py @@ -42,7 +42,12 @@ def reduce_shards_to_doc(task: SentimentTask, shards: Iterable[Doc]) -> Doc: setattr( doc._, task.field, - sum([score * weight for score, weight in zip(sent_scores, weights)]), + sum( + [ + (score if score else 0) * weight + for score, weight in zip(sent_scores, weights) + ] + ), ) return doc diff --git a/spacy_llm/tests/models/test_dolly.py b/spacy_llm/tests/models/test_dolly.py index 6a6dc32f..d88317e9 100644 --- a/spacy_llm/tests/models/test_dolly.py +++ b/spacy_llm/tests/models/test_dolly.py @@ -27,6 +27,7 @@ [components.llm] factory = "llm" +save_io = True [components.llm.task] @llm_tasks = "spacy.NoOp.v1" diff --git a/spacy_llm/tests/models/test_yi.py b/spacy_llm/tests/models/test_yi.py new file mode 100644 index 00000000..95b73136 --- /dev/null +++ b/spacy_llm/tests/models/test_yi.py @@ -0,0 +1,71 @@ +import copy + +import pytest +import spacy +from confection import Config # type: ignore[import] +from thinc.compat import has_torch_cuda_gpu + +from ...compat import torch + +_PIPE_CFG = { + "model": { + "@llm_models": "spacy.Yi.v1", + "name": "Yi-6B-chat", + }, + "task": {"@llm_tasks": "spacy.NoOp.v1"}, +} + +_NLP_CONFIG = """ + +[nlp] +lang = "en" +pipeline = ["llm"] +batch_size = 128 + +[components] + +[components.llm] +factory = "llm" + +[components.llm.task] +@llm_tasks = "spacy.NoOp.v1" + +[components.llm.model] +@llm_models = "spacy.Yi.v1" +name = "Yi-6B" +""" + + +@pytest.mark.gpu +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +@pytest.mark.skip( + reason="CI runner fails with 'cutlassF: no kernel found to launch!' 
- to be investigated" +) +def test_init(): + """Test initialization and simple run.""" + nlp = spacy.blank("en") + cfg = copy.deepcopy(_PIPE_CFG) + nlp.add_pipe("llm", config=cfg) + nlp("This is a test.") + torch.cuda.empty_cache() + + +@pytest.mark.gpu +@pytest.mark.skip(reason="CI runner needs more GPU memory") +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +def test_init_from_config(): + orig_config = Config().from_str(_NLP_CONFIG) + nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True) + assert nlp.pipe_names == ["llm"] + torch.cuda.empty_cache() + + +@pytest.mark.gpu +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +def test_invalid_model(): + orig_config = Config().from_str(_NLP_CONFIG) + config = copy.deepcopy(orig_config) + config["components"]["llm"]["model"]["name"] = "x" + with pytest.raises(ValueError, match="unexpected value; permitted"): + spacy.util.load_model_from_config(config, auto_fill=True) + torch.cuda.empty_cache() diff --git a/spacy_llm/tests/models/test_zephyr.py b/spacy_llm/tests/models/test_zephyr.py new file mode 100644 index 00000000..e026854a --- /dev/null +++ b/spacy_llm/tests/models/test_zephyr.py @@ -0,0 +1,68 @@ +import copy + +import pytest +import spacy +from confection import Config # type: ignore[import] +from thinc.compat import has_torch_cuda_gpu + +from ...compat import torch + +_PIPE_CFG = { + "model": { + "@llm_models": "spacy.Zephyr.v1", + "name": "zephyr-7b-beta", + }, + "task": {"@llm_tasks": "spacy.NoOp.v1"}, +} + +_NLP_CONFIG = """ + +[nlp] +lang = "en" +pipeline = ["llm"] +batch_size = 128 + +[components] + +[components.llm] +factory = "llm" + +[components.llm.task] +@llm_tasks = "spacy.NoOp.v1" + +[components.llm.model] +@llm_models = "spacy.Zephyr.v1" +name = "zephyr-7b-beta" +""" + + +@pytest.mark.gpu +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +def test_init(): + """Test initialization and simple run.""" + nlp = spacy.blank("en") + cfg = copy.deepcopy(_PIPE_CFG) + nlp.add_pipe("llm", config=cfg) + nlp("This is a test.") + torch.cuda.empty_cache() + + +@pytest.mark.gpu +@pytest.mark.skip(reason="CI runner needs more GPU memory") +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +def test_init_from_config(): + orig_config = Config().from_str(_NLP_CONFIG) + nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True) + assert nlp.pipe_names == ["llm"] + torch.cuda.empty_cache() + + +@pytest.mark.gpu +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +def test_invalid_model(): + orig_config = Config().from_str(_NLP_CONFIG) + config = copy.deepcopy(orig_config) + config["components"]["llm"]["model"]["name"] = "x" + with pytest.raises(ValueError, match="unexpected value; permitted"): + spacy.util.load_model_from_config(config, auto_fill=True) + torch.cuda.empty_cache()
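Usage sketch (not part of the patch): the snippet below mirrors the `_PIPE_CFG` used in the tests above and shows how one of the newly registered models can be plugged into an `llm` pipe. It assumes `spacy-llm` is installed with its `transformers` dependencies and that a CUDA GPU with enough memory is available; swap in `"spacy.Yi.v1"` / `"Yi-6B-chat"` to try the Yi wrapper instead.

```python
# Minimal sketch: run the new Zephyr wrapper through an `llm` pipe with the
# spacy.NoOp.v1 placeholder task, as the tests above do.
import spacy

nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        "task": {"@llm_tasks": "spacy.NoOp.v1"},
        "model": {"@llm_models": "spacy.Zephyr.v1", "name": "zephyr-7b-beta"},
    },
)
doc = nlp("This is a test.")
```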