
Commit 117f689

Revert "Add registry functions to instantiate models by provider (#428)" (#469)
This reverts commit ff07682.
1 parent 5b49105 commit 117f689


46 files changed: +97 −363 lines

README.md

Lines changed: 1 addition & 2 deletions
@@ -119,8 +119,7 @@ factory = "llm"
 labels = ["COMPLIMENT", "INSULT"]

 [components.llm.model]
-@llm_models = "spacy.OpenAI.v1"
-name = "gpt-4"
+@llm_models = "spacy.GPT-4.v2"
 ```

 Now run:
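For context, a minimal sketch of what the retained per-model entry looks like when used from Python rather than a config file — the task choice, labels, and example text are assumptions, and OPENAI_API_KEY must be set:

```python
# Minimal sketch, assuming spacy-llm is installed and OPENAI_API_KEY is set.
# After this revert, configs reference the per-model entry "spacy.GPT-4.v2"
# rather than the removed provider-level "spacy.OpenAI.v1".
import spacy

nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        "task": {"@llm_tasks": "spacy.TextCat.v3", "labels": ["COMPLIMENT", "INSULT"]},
        "model": {"@llm_models": "spacy.GPT-4.v2"},
    },
)
doc = nlp("You look great today!")
print(doc.cats)  # label scores assigned by the LLM-backed textcat task
```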

pyproject.toml

Lines changed: 1 addition & 3 deletions
@@ -27,9 +27,7 @@ filterwarnings = [
     "ignore:^.*The `construct` method is deprecated.*",
     "ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*",
     "ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*",
-    "ignore:^.*was deprecated in langchain-community.*",
-    "ignore:^.*was deprecated in LangChain 0.0.1.*",
-    "ignore:^.*the load_module() method is deprecated and slated for removal in Python 3.12.*"
+    "ignore:^.*was deprecated in langchain-community.*"
 ]
 markers = [
     "external: interacts with a (potentially cost-incurring) third-party API",

requirements-dev.txt

Lines changed: 1 addition & 2 deletions
@@ -13,8 +13,7 @@ langchain>=0.1,<0.2; python_version>="3.9"
 openai>=0.27,<=0.28.1; python_version>="3.9"

 # Necessary for running all local models on GPU.
-# TODO: transformers > 4.38 causes bug in model handling due to unknown factors. To be investigated.
-transformers[sentencepiece]>=4.0.0,<=4.38
+transformers[sentencepiece]>=4.0.0
 torch
 einops>=0.4
spacy_llm/models/hf/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -4,14 +4,12 @@
 from .llama2 import llama2_hf
 from .mistral import mistral_hf
 from .openllama import openllama_hf
-from .registry import huggingface_v1
 from .stablelm import stablelm_hf

 __all__ = [
     "HuggingFace",
     "dolly_hf",
     "falcon_hf",
-    "huggingface_v1",
     "llama2_hf",
     "mistral_hf",
     "openllama_hf",

spacy_llm/models/hf/mistral.py

Lines changed: 2 additions & 1 deletion
@@ -99,7 +99,8 @@ def mistral_hf(
     name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names().
     config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
     config_run (Optional[Dict[str, Any]]): HF config for running the model.
-    RETURNS (Mistral): Mistral instance that can execute a set of prompts and return the raw responses.
+    RETURNS (Callable[[Iterable[str]], Iterable[str]]): Mistral instance that can execute a set of prompts and return
+        the raw responses.
     """
     return Mistral(
         name=name, config_init=config_init, config_run=config_run, context_length=8000
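Since this hunk restores the callable return type in the docstring, here is a hedged usage sketch following that stated type — the model name is an assumption, and the weights are fetched from the Hugging Face Hub on first use, so this is illustrative rather than a verified snippet:

```python
# Hedged sketch: per the docstring above, mistral_hf returns a callable that
# maps an iterable of prompts to the raw string responses.
from spacy_llm.models.hf.mistral import mistral_hf

model = mistral_hf(name="Mistral-7B-v0.1")  # assumed name; downloads HF weights
responses = list(model(["What is the capital of France?"]))
print(responses[0])
```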

spacy_llm/models/hf/registry.py

Lines changed: 0 additions & 51 deletions
This file was deleted.
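The deleted file held the provider-level `huggingface_v1` entry point (whose import is also removed from `__init__.py` above). Its actual source is not shown in this diff; a purely hypothetical sketch of the dispatch-by-name pattern such an entry point implies:

```python
# Hypothetical sketch only -- the deleted file's real contents are not part of
# this diff. Illustrates a provider-level factory dispatching a model name to
# the per-model factories that remain registered after the revert.
from typing import Any, Callable, Dict

# Mapping of model-name prefixes to factories; entries here are placeholders.
_HF_FACTORIES: Dict[str, Callable[..., Any]] = {}


def huggingface_v1(name: str, **kwargs: Any) -> Any:
    """Instantiate any supported Hugging Face model by name."""
    for prefix, factory in _HF_FACTORIES.items():
        if name.lower().startswith(prefix):
            return factory(name=name, **kwargs)
    raise ValueError(f"No Hugging Face factory found for model name: {name}")
```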

spacy_llm/models/langchain/model.py

Lines changed: 0 additions & 1 deletion
@@ -99,7 +99,6 @@ def query_langchain(
     prompts (Iterable[Iterable[Any]]): Prompts to execute.
     RETURNS (Iterable[Iterable[Any]]): LLM responses.
     """
-    assert callable(model)
     return [
         [model.invoke(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts
     ]

spacy_llm/models/rest/anthropic/registry.py

Lines changed: 0 additions & 37 deletions
@@ -7,43 +7,6 @@
 from .model import Anthropic, Endpoints


-@registry.llm_models("spacy.Anthropic.v1")
-def anthropic_v1(
-    name: str,
-    config: Dict[Any, Any] = SimpleFrozenDict(),
-    strict: bool = Anthropic.DEFAULT_STRICT,
-    max_tries: int = Anthropic.DEFAULT_MAX_TRIES,
-    interval: float = Anthropic.DEFAULT_INTERVAL,
-    max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME,
-    context_length: Optional[int] = None,
-) -> Anthropic:
-    """Returns Anthropic model instance using REST to prompt API.
-    config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
-    name (str): Name of model to use.
-    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
-        or other response object that does not conform to the expectation of how a well-formed response object from
-        this API should look like). If False, the API error responses are returned by __call__(), but no error will
-        be raised.
-    max_tries (int): Max. number of tries for API request.
-    interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
-        at each retry.
-    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
-    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
-        natively provided by spacy-llm.
-    RETURNS (Anthropic): Instance of Anthropic model.
-    """
-    return Anthropic(
-        name=name,
-        endpoint=Endpoints.COMPLETIONS.value,
-        config=config,
-        strict=strict,
-        max_tries=max_tries,
-        interval=interval,
-        max_request_time=max_request_time,
-        context_length=context_length,
-    )
-
-
 @registry.llm_models("spacy.Claude-2.v2")
 def anthropic_claude_2_v2(
     config: Dict[Any, Any] = SimpleFrozenDict(),

spacy_llm/models/rest/cohere/registry.py

Lines changed: 1 addition & 38 deletions
@@ -7,43 +7,6 @@
 from .model import Cohere, Endpoints


-@registry.llm_models("spacy.Cohere.v1")
-def cohere_v1(
-    name: str,
-    config: Dict[Any, Any] = SimpleFrozenDict(),
-    strict: bool = Cohere.DEFAULT_STRICT,
-    max_tries: int = Cohere.DEFAULT_MAX_TRIES,
-    interval: float = Cohere.DEFAULT_INTERVAL,
-    max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
-    context_length: Optional[int] = None,
-) -> Cohere:
-    """Returns Cohere model instance using REST to prompt API.
-    config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
-    name (str): Name of model to use.
-    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
-        or other response object that does not conform to the expectation of how a well-formed response object from
-        this API should look like). If False, the API error responses are returned by __call__(), but no error will
-        be raised.
-    max_tries (int): Max. number of tries for API request.
-    interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
-        at each retry.
-    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
-    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
-        natively provided by spacy-llm.
-    RETURNS (Cohere): Instance of Cohere model.
-    """
-    return Cohere(
-        name=name,
-        endpoint=Endpoints.COMPLETION.value,
-        config=config,
-        strict=strict,
-        max_tries=max_tries,
-        interval=interval,
-        max_request_time=max_request_time,
-        context_length=context_length,
-    )
-
-
 @registry.llm_models("spacy.Command.v2")
 def cohere_command_v2(
     config: Dict[Any, Any] = SimpleFrozenDict(),
@@ -93,7 +56,7 @@ def cohere_command(
     max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
 ) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
     """Returns Cohere instance for 'command' model using REST to prompt API.
-    name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Name of model to use.
+    name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Model to use.
     config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
     strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
         or other response object that does not conform to the expectation of how a well-formed response object from

spacy_llm/models/rest/openai/registry.py

Lines changed: 0 additions & 41 deletions
@@ -8,47 +8,6 @@

 _DEFAULT_TEMPERATURE = 0.0

-
-@registry.llm_models("spacy.OpenAI.v1")
-def openai_v1(
-    name: str,
-    config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
-    strict: bool = OpenAI.DEFAULT_STRICT,
-    max_tries: int = OpenAI.DEFAULT_MAX_TRIES,
-    interval: float = OpenAI.DEFAULT_INTERVAL,
-    max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME,
-    endpoint: Optional[str] = None,
-    context_length: Optional[int] = None,
-) -> OpenAI:
-    """Returns OpenAI model instance using REST to prompt API.
-
-    config (Dict[Any, Any]): LLM config passed on to the model's initialization.
-    name (str): Model name to use. Can be any model name supported by the OpenAI API.
-    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
-        or other response object that does not conform to the expectation of how a well-formed response object from
-        this API should look like). If False, the API error responses are returned by __call__(), but no error will
-        be raised.
-    max_tries (int): Max. number of tries for API request.
-    interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
-        at each retry.
-    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
-    endpoint (Optional[str]): Endpoint to set. Defaults to standard endpoint.
-    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
-        natively provided by spacy-llm.
-    RETURNS (OpenAI): OpenAI model instance.
-    """
-    return OpenAI(
-        name=name,
-        endpoint=endpoint or Endpoints.CHAT.value,
-        config=config,
-        strict=strict,
-        max_tries=max_tries,
-        interval=interval,
-        max_request_time=max_request_time,
-        context_length=context_length,
-    )
-
-
 """
 Parameter explanations:
 strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
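With the provider-level entries gone, the per-model factories this revert keeps can still be resolved programmatically. A hedged sketch, assuming spacy-llm's catalogue-based registry (an actual prompt call would additionally require OPENAI_API_KEY in the environment):

```python
# Hedged sketch: resolve the retained per-model registry entry directly.
from spacy_llm.registry import registry

gpt4_factory = registry.llm_models.get("spacy.GPT-4.v2")
model = gpt4_factory()  # defaults to name="gpt-4"; needs OPENAI_API_KEY set
```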
