|
7 | 7 | from .model import Cohere, Endpoints |
8 | 8 |
9 | 9 |
10 | | -@registry.llm_models("spacy.Cohere.v1") |
11 | | -def cohere_v1( |
12 | | - name: str, |
13 | | - config: Dict[Any, Any] = SimpleFrozenDict(), |
14 | | - strict: bool = Cohere.DEFAULT_STRICT, |
15 | | - max_tries: int = Cohere.DEFAULT_MAX_TRIES, |
16 | | - interval: float = Cohere.DEFAULT_INTERVAL, |
17 | | - max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME, |
18 | | - context_length: Optional[int] = None, |
19 | | -) -> Cohere: |
20 | | - """Returns Cohere model instance using REST to prompt API. |
21 | | - config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. |
22 | | - name (str): Name of model to use. |
23 | | - strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON |
24 | | - or other response object that does not conform to the expectation of how a well-formed response object from |
25 | | - this API should look like). If False, the API error responses are returned by __call__(), but no error will |
26 | | - be raised. |
27 | | - max_tries (int): Max. number of tries for API request. |
28 | | - interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff |
29 | | - at each retry. |
30 | | - max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. |
31 | | - context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length |
32 | | - natively provided by spacy-llm. |
33 | | - RETURNS (Cohere): Instance of Cohere model. |
34 | | - """ |
35 | | - return Cohere( |
36 | | - name=name, |
37 | | - endpoint=Endpoints.COMPLETION.value, |
38 | | - config=config, |
39 | | - strict=strict, |
40 | | - max_tries=max_tries, |
41 | | - interval=interval, |
42 | | - max_request_time=max_request_time, |
43 | | - context_length=context_length, |
44 | | - ) |
45 | | - |
46 | | - |
47 | 10 | @registry.llm_models("spacy.Command.v2") |
48 | 11 | def cohere_command_v2( |
49 | 12 | config: Dict[Any, Any] = SimpleFrozenDict(), |
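
For orientation: this diff removes the generic `spacy.Cohere.v1` factory while leaving the model-specific `spacy.Command.v2` registration in place. Factories registered via `registry.llm_models` are looked up by name when a spacy-llm pipeline is assembled. A minimal usage sketch, assuming spacy-llm is installed and a Cohere API key is available in the environment (the `CO_API_KEY` variable name is an assumption; check the spacy-llm docs), and using `spacy.NER.v3` purely as a placeholder task:

import spacy

nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        # Placeholder task; any registered @llm_tasks entry works here.
        "task": {"@llm_tasks": "spacy.NER.v3", "labels": ["PERSON", "ORG"]},
        # Resolves to the cohere_command_v2 factory kept by this commit.
        "model": {"@llm_models": "spacy.Command.v2", "name": "command"},
    },
)
doc = nlp("Cohere is headquartered in Toronto.")
print([(ent.text, ent.label_) for ent in doc.ents])
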
@@ -93,7 +56,7 @@ def cohere_command( |
93 | 56 | max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME, |
94 | 57 | ) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: |
95 | 58 | """Returns Cohere instance for 'command' model using REST to prompt API. |
96 | | - name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Name of model to use. |
| 59 | + name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Model to use. |
97 | 60 | config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. |
98 | 61 | strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON |
99 | 62 | or other response object that does not conform to the expectation of how a well-formed response object from |
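
Two details from the retained docstring are worth unpacking. The return type `Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]` means the model maps nested batches of prompts to equally shaped nested batches of responses (the inner grouping presumably corresponds to shards of a single doc, per the sharding note in the docstring). And the retry parameters combine as a base-2 exponential backoff: the n-th retry waits `interval * 2**n` seconds, bounded by `max_tries` attempts and `max_request_time` total seconds. A minimal standalone sketch of that policy, with a hypothetical `call_api` callable standing in for the REST request and `ConnectionError` as the retried failure (illustrative only, not spacy-llm's actual implementation):

import time
from typing import Any, Callable

def request_with_backoff(
    call_api: Callable[[], Any],
    max_tries: int = 5,
    interval: float = 1.0,
    max_request_time: float = 30.0,
) -> Any:
    # Retry with base-2 exponential backoff: sleep interval * 2**attempt
    # between attempts; give up after max_tries attempts or once the
    # total elapsed time exceeds max_request_time.
    start = time.monotonic()
    for attempt in range(max_tries):
        try:
            return call_api()
        except ConnectionError:
            elapsed = time.monotonic() - start
            if attempt == max_tries - 1 or elapsed >= max_request_time:
                raise
            time.sleep(interval * 2**attempt)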