Skip to content

Commit 555f97b

Browse files
mo374zbaskaryan
andauthored
community[patch]: fix model initialization bug for deepinfra (#25727)
### Description adds an init method to ChatDeepInfra to set the model_name attribute accordings to the argument ### Issue currently, the model_name specified by the user during initialization of the ChatDeepInfra class is never set. Therefore, it always chooses the default model (meta-llama/Llama-2-70b-chat-hf, however probably since this is deprecated it always uses meta-llama/Llama-3-70b-Instruct). We stumbled across this issue and fixed it as proposed in this pull request. Feel free to change the fix according to your coding guidelines and style, this is just a proposal and we want to draw attention to this problem. ### Dependencies no additional dependencies required Feel free to contact me or @timo282 and @finitearth if you have any questions. --------- Co-authored-by: Bagatur <[email protected]> Co-authored-by: Bagatur <[email protected]>
1 parent a052173 commit 555f97b

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

libs/community/langchain_community/chat_models/deepinfra.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,11 @@ class ChatDeepInfra(BaseChatModel):
222222
streaming: bool = False
223223
max_retries: int = 1
224224

225+
class Config:
226+
"""Configuration for this pydantic object."""
227+
228+
allow_population_by_field_name = True
229+
225230
@property
226231
def _default_params(self) -> Dict[str, Any]:
227232
"""Get the default parameters for calling OpenAI API."""
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from langchain_community.chat_models import ChatDeepInfra
2+
3+
4+
def test_deepinfra_model_name_param() -> None:
5+
llm = ChatDeepInfra(model_name="foo") # type: ignore[call-arg]
6+
assert llm.model_name == "foo"
7+
8+
9+
def test_deepinfra_model_param() -> None:
10+
llm = ChatDeepInfra(model="foo")
11+
assert llm.model_name == "foo"

0 commit comments

Comments
 (0)