I have a MacBook M3 Pro with 18GB memory and I’m building a retrieval-augmented generation setup using LangChain with the Llama-3.2-3B-Instruct model. My vector database is Milvus and I’m working in Jupyter notebook.
The Problem: When I call DocumentQA.from_chain_type, my notebook cell keeps running forever. I waited about 15 minutes but it never completes.
from langchain.chains import DocumentQA
qa_system = DocumentQA.from_chain_type(
llm=language_model,
retriever=document_retriever,
return_source_documents=True,
chain_type_kwargs={"prompt": custom_prompt}
)
result = qa_system.invoke({"query": user_question})
Here’s my custom LLM wrapper:
from langchain.llms.base import LLM
from typing import List, Dict
from pydantic import PrivateAttr
class CustomHFLLM(LLM):
_model_pipeline: any = PrivateAttr()
def __init__(self, pipeline):
super().__init__()
self._model_pipeline = pipeline
def _call(self, prompt: str, stop: List[str] = None) -> str:
output = self._model_pipeline(prompt, num_return_sequences=1)
return output[0]["generated_text"]
@property
def _identifying_params(self):
return {"name": "CustomHFLLM"}
@property
def _llm_type(self):
return "custom"
language_model = CustomHFLLM(pipeline=text_pipeline)
My pipeline setup:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
model_id = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token)
model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=auth_token)
text_pipeline = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
device=0,
max_new_tokens=256,
temperature=0.7,
top_p=0.9,
truncation=True
)
Prompt configuration:
template_string = """
You are an AI assistant. Answer the question based on the provided context.
If the context doesn't contain enough information, say you don't know.
Context:
{context}
Question:
{question}
Response:
"""
custom_prompt = PromptTemplate(
input_variables=["context", "question"],
template=template_string
)
Custom retriever class:
class CustomMilvusRetriever(BaseRetriever, BaseModel):
milvus_collection: any
embed_func: Callable[[str], np.ndarray]
content_field: str
vector_field: str
result_count: int = 5
def get_relevant_documents(self, query: str) -> List[Dict]:
query_vector = self.embed_func(query)
search_config = {"metric_type": "IP", "params": {"nprobe": 10}}
search_results = self.milvus_collection.search(
data=[query_vector],
anns_field=self.vector_field,
param=search_config,
limit=self.result_count,
output_fields=[self.content_field]
)
docs = []
for match in search_results[0]:
docs.append(
Document(
page_content=match.entity.get(self.content_field),
metadata={"score": match.distance}
)
)
return docs
async def aget_relevant_documents(self, query: str) -> List[Dict]:
return self.get_relevant_documents(query)
document_retriever = CustomMilvusRetriever(
milvus_collection=my_collection,
embed_func=embedding_model.embed_query,
content_field="text",
vector_field="embedding",
result_count=5
)
I confirmed MPS acceleration works:
import torch
if torch.backends.mps.is_available():
print("MPS acceleration available")
Update: Adding verbose mode shows the chain enters successfully and formats the prompt correctly with retrieved context, but then hangs at the LLM generation step.
Any ideas what could cause this infinite loop?