#!/usr/bin/env python3

import os
import shutil
import datetime

from langchain_community.chat_models import ChatOllama
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_community.document_loaders import MWDumpLoader
from langchain_core.prompts import ChatPromptTemplate

from langchain_community.llms import Ollama
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_transformers.embeddings_redundant_filter import EmbeddingsRedundantFilter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_core.prompts import HumanMessagePromptTemplate
from langchain_core.prompts import SystemMessagePromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from langchain.utils.html import (PREFIXES_TO_IGNORE_REGEX,
                                  SUFFIXES_TO_IGNORE_REGEX)

from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_core.prompts import PromptTemplate
from enum import Enum

text_model = "llama3"  # ollama model name
embedding_model = "all-MiniLM-L6-v2"  # huggingface sentence-transformers model
wikilocation = os.environ.get("RAG_WIKI_LOCATION", "wiki/current.xml")  # mediawiki xml dump to index
index_dir = "./index"
date = datetime.datetime.now().strftime("%Y-%m-%d")

# module-level conversation chain, set up by init_chat()
conversation = None

ollamaUrl = os.environ.get("OLLAMA_HOST", "http://localhost:11434")

class Emotion(str, Enum):
    NEUTRAL = "neutral"
    HAPPY = "happy"
    SAD = "sad"
    ANGRY = "angry"
    SURPRISED = "surprised"
    CONFUSED = "confused"
    EXCITED = "excited"
    CALM = "calm"


class Action(str, Enum):
    NOTHING = "nothing"
    STUTTER = "stutter"
    SQUEAL = "squeal"
    MEWOW = "mewow"
    SMUG = "smug"
    WAGS_TAIL = "wags_tail"
    WINK = "wink"
    NOD = "nod"
    LAUGH = "laugh"
    SIGH = "sigh"
    GLOOMY = "gloomy"
    LOOK_AWAY = "look_away"
    LOOK_TOWARDS_YOU = "look_towards_you"

def chat(question="Hello, how are you today?", image_description=""):
    # example with rag
    global conversation
    global format_instructions

    # TODO: implement chat history and memory, with storage to disk
    chat_history = []

    response = conversation({
        "question": question,
        "image_description": image_description,
        "chat_history": chat_history,
        "format_instructions": format_instructions,
        "date": date,
    })
    # print(response)
    answer = response['answer']

    result = output_parser.parse(answer)

    # enforce that the parsed values are valid enum members
    try:
        result["emotion"] = Emotion(result["emotion"])
    except (KeyError, ValueError):
        print(f"Could not parse emotion: {result.get('emotion')}")
        result["emotion"] = Emotion.NEUTRAL

    for i in range(len(result["actions"])):
        try:
            result["actions"][i] = Action(result["actions"][i])
        except ValueError:
            print(f"Could not parse action: {result['actions'][i]}")
            result["actions"][i] = Action.NOTHING

    # print(result)

    return result
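
# A minimal sketch for the TODO above (not part of the original script): persist the
# chat history to disk as JSON so it survives restarts. The file name and the
# (question, answer) pair format are assumptions; ConversationalRetrievalChain expects
# tuples, so the loader converts the stored lists back into tuples.
import json

def load_chat_history(path="chat_history.json"):
    # return previously saved (question, answer) pairs, or an empty list
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return [tuple(turn) for turn in json.load(f)]
    return []

def save_chat_history(chat_history, path="chat_history.json"):
    # overwrite the history file with the current list of (question, answer) pairs
    with open(path, "w", encoding="utf-8") as f:
        json.dump(chat_history, f, ensure_ascii=False, indent=2)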

# some toy functions to interact with the llm
endstring = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"  # the llama3 end tokens do not seem to be handled properly in langchain yet

def simple_stream(prompt="A question to ask the model", temperature=0.5):
    # stop tokens need to be given manually for llama3 for now
    llm = Ollama(model=text_model, temperature=temperature, base_url=ollamaUrl,
                 stop=["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>", "<|reserved_special_token|>"])
    return llm.stream(prompt)


def simple(prompt="A question to ask the model", temperature=0.5):
    stream = simple_stream(prompt, temperature)
    result = ""
    for line in stream:
        result += line.rstrip("\n")
        if result.endswith(endstring):
            return result.replace(endstring, "")
    return result
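
# Illustrative usage of the toy helpers above (assumes an Ollama server serving the
# llama3 model is reachable at ollamaUrl):
#   for chunk in simple_stream("Tell me a short joke."):
#       print(chunk, end="", flush=True)
#   print(simple("What is the meaning of life? (answer briefly)"))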


# https://scribe.rip/rahasak/build-rag-application-using-a-llm-running-on-local-computer-with-ollama-and-langchain-e6513853fda0
def init_index():
    # remove the current index
    if os.path.exists(index_dir):
        shutil.rmtree(index_dir)

    # Load data from MediaWiki dump
    documents = MWDumpLoader(wikilocation).load()
    # TODO: add chat history to the documents

    # Split text
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    documents = text_splitter.split_documents(documents)

    # Apply the redundant filter
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model,
        multi_process=True,
        # encode_kwargs={"normalize_embeddings": True},
    )
    redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
    documents = redundant_filter.transform_documents(documents)

    vectordb = Chroma.from_documents(
        documents=documents,
        embedding=embeddings,
        persist_directory=index_dir,
        collection_name="pvv-wiki",
    )
    vectordb.persist()
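
# Hedged sanity-check sketch (not in the original script): query the freshly persisted
# collection directly. The function name and query string are illustrative only.
def query_index_example(query="What is our dns server named?"):
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
    vectordb = Chroma(
        persist_directory=index_dir,
        embedding_function=embeddings,
        collection_name="pvv-wiki",
    )
    # return the two most similar wiki chunks for a quick manual inspection
    return vectordb.similarity_search(query, k=2)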


response_schemas = [
    ResponseSchema(
        name="response",
        description="reply to the user's question or statement.",
    ),
    ResponseSchema(
        name="emotion",
        description=f"emotion expressed in the response, selected from a set of possible options {list(str(e.value) for e in Emotion)}",
        type="Emotion",
    ),
    ResponseSchema(
        name="actions",
        description=f"List of actions to take at random in response to the user's question or statement from the set {list(str(a.value) for a in Action)}",
        type="List[Action]",
    ),
]

output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
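
# Illustration (an assumption, not from the original file): the format instructions ask
# the model to reply with a fenced JSON block matching the schemas above, e.g.
#   ```json
#   {"response": "Hi there!", "emotion": "happy", "actions": ["wink", "wags_tail"]}
#   ```
# which output_parser.parse() then turns into a plain Python dict.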

general_system_template = r"""
Given a specific context, please give a short answer to the question. Use the relevant context to try and find a possible outcome. If the data does not help, be uncertain in the final answer.
The current date is {date}
----
{context}
----
You may refer to what you can see in front of you in the description below. Any reference to the camera or image should be interpreted as "you", "your eyes" or what you can see:
{image_description}
----
Do not refer to yourself as an AI.
Avoid expressions in the response.

You are a cute anime character named pvv chan. You like programming, linux, open source and board games.
Without referring to yourself, reply to the human talking to you.
{format_instructions}
"""
general_user_template = "Question:```{question}```"
messages = [
    SystemMessagePromptTemplate.from_template(general_system_template),
    HumanMessagePromptTemplate.from_template(general_user_template),
]
qa_prompt = ChatPromptTemplate.from_messages(messages)


def init_chat():
    global conversation

    # load index from local directory
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model,
        multi_process=True,
        # encode_kwargs={"normalize_embeddings": True},
    )
    # use the same collection name as in init_index(); otherwise Chroma would open its
    # default collection, which does not contain the indexed wiki pages
    vectordb = Chroma(
        persist_directory=index_dir,
        embedding_function=embeddings,
        collection_name="pvv-wiki",
    )

    llm = Ollama(
        model=text_model,
        base_url=ollamaUrl,
        verbose=True,
    )

    # create conversation
    conversation = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vectordb.as_retriever(search_kwargs={"k": 2}),  # number of documents to use for the response
        # retriever=vectordb.as_retriever(),
        return_source_documents=True,
        verbose=True,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
    )


if __name__ == "__main__":
    # print(simple(prompt="What is the meaning of life? (answer short)"))

    print("initialising index")
    # init_index()
    print("initialising chat")
    init_chat()

    print("chatting")
    print(chat(question="Hello, how are you today? What is our dns server named?"))