pvv-chan/llm.py

229 lines
8.1 KiB
Python
Raw Normal View History

2024-05-27 18:42:18 +02:00
#!python3
import os
import shutil
import datetime
from langchain_community.chat_models import ChatOllama
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_community.document_loaders import MWDumpLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms import Ollama
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_transformers.embeddings_redundant_filter import EmbeddingsRedundantFilter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_core.prompts import HumanMessagePromptTemplate
from langchain_core.prompts import SystemMessagePromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from langchain.utils.html import (PREFIXES_TO_IGNORE_REGEX,
SUFFIXES_TO_IGNORE_REGEX)
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import Ollama
from enum import Enum
# Ollama model used for text generation.
text_model = "llama3"
# HuggingFace embedding model used to build and query the vector index.
embedding_model = "all-MiniLM-L6-v2"
# MediaWiki XML dump to index; override with the RAG_WIKI_LOCATION env var.
wikilocation = os.environ.get("RAG_WIKI_LOCATION", "wiki/current.xml")
# Directory where the Chroma index is persisted.
index_dir = "./index"
# Current date as YYYY-MM-DD, evaluated once when the module is imported.
date = datetime.datetime.now().strftime("%Y-%m-%d")
# Retrieval chain built by init_chat(); stays None until that has run.
conversation = None
# Ollama server URL; override with the OLLAMA_HOST env var.
ollamaUrl = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
class Emotion(str, Enum):
    """Emotions the character may express alongside a reply.

    Member order matters: the values are listed, in this order, in the
    output-format instructions given to the model.
    """

    NEUTRAL = "neutral"
    HAPPY = "happy"
    SAD = "sad"
    ANGRY = "angry"
    SURPRISED = "surprised"
    CONFUSED = "confused"
    EXCITED = "excited"
    CALM = "calm"
class Action(str, Enum):
    """Actions the character may perform alongside a reply.

    Member order matters: the values are listed, in this order, in the
    output-format instructions given to the model.
    """

    NOTHING = "nothing"
    STUTTER = "stutter"
    SQUEAL = "squeal"
    MEWOW = "mewow"
    SMUG = "smug"
    WAGS_TAIL = "wags_tail"
    WINK = "wink"
    NOD = "nod"
    LAUGH = "laugh"
    SIGH = "sigh"
    GLOOMY = "gloomy"
    LOOK_AWAY = "look_away"
    LOOK_TOWARDS_YOU = "look_towards_you"
def _coerce_enum(enum_cls, value, default, label):
    """Return the member of *enum_cls* whose value is *value*, or *default*.

    Model output is untrusted, so invalid, missing (None) or non-hashable values
    are reported with a "Could not parse <label>: <value>" message and replaced
    by *default* instead of raising.
    """
    try:
        return enum_cls(value)
    except (ValueError, TypeError):
        print(f"Could not parse {label}: {value}")
        return default


def chat(question="Hello, how are you today?", image_description=""):
    """Ask the retrieval chain *question* and return the parsed structured reply.

    Requires init_chat() to have been called first.

    Args:
        question: what the user said.
        image_description: text describing what the character currently "sees";
            inserted into the system prompt.

    Returns:
        The dict produced by ``output_parser`` (keys from ``response_schemas``),
        with ``"emotion"`` coerced to an ``Emotion`` (NEUTRAL when missing/invalid)
        and ``"actions"`` coerced to a list of ``Action`` (invalid entries become
        NOTHING; a missing/null value becomes an empty list).

    Raises:
        RuntimeError: if init_chat() has not been called yet.
    """
    if conversation is None:
        raise RuntimeError("chat() called before init_chat(); no conversation chain exists")
    # TODO: implement chat history and memory, with storage to disk
    chat_history = []
    response = conversation({
        "question": question,
        "image_description": image_description,
        "chat_history": chat_history,
        "format_instructions": format_instructions,
        # Computed per call: the module-level `date` is fixed at import time and
        # goes stale in a long-running process.
        "date": datetime.datetime.now().strftime("%Y-%m-%d"),
    })
    result = output_parser.parse(response["answer"])

    # Enforce the enum types; the model may omit keys or invent values.
    result["emotion"] = _coerce_enum(Emotion, result.get("emotion"), Emotion.NEUTRAL, "emotion")

    raw_actions = result.get("actions")
    if raw_actions is None:
        raw_actions = []
    elif not isinstance(raw_actions, (list, tuple)):
        # A single bare value (e.g. "wink") instead of a list.
        raw_actions = [raw_actions]
    result["actions"] = [_coerce_enum(Action, a, Action.NOTHING, "action") for a in raw_actions]
    return result
# Some toy helpers to talk to the LLM directly, without retrieval.
# llama3 end-of-turn + assistant header sequence. The llama3 end tokens do not
# seem to be handled properly by langchain yet, so this is stripped manually.
endstring = "<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>"
def simple_stream(prompt="A question to ask the model", temperature=0.5,):
    """Stream the completion of *prompt* from the Ollama text model.

    Returns whatever ``Ollama.stream`` yields for the prompt (text chunks).
    """
    # The llama3 end/header tokens have to be passed as explicit stop sequences for now.
    llama3_stop_tokens = [
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|eot_id|>",
        "<|reserved_special_token|>",
    ]
    model = Ollama(
        model=text_model,
        temperature=temperature,
        base_url=ollamaUrl,
        stop=llama3_stop_tokens,
    )
    return model.stream(prompt)
def _join_stream(chunks, terminator):
    """Concatenate streamed text *chunks* into one string.

    Stream chunks are arbitrary fragments (typically tokens), not lines, so
    newlines inside the answer are kept and only trailing newlines of the final
    text are removed. As soon as the accumulated text ends with *terminator*,
    every occurrence of it is removed and the rest of the stream is ignored.
    """
    text = ""
    for chunk in chunks:
        text += chunk
        if terminator and text.endswith(terminator):
            text = text.replace(terminator, "")
            break
    return text.rstrip("\n")


def simple(prompt="A question to ask the model", temperature=0.5,):
    """Return the full answer to *prompt* as one string (non-streaming).

    Uses simple_stream() and cuts the output at the llama3 ``endstring``
    sequence if the model emits it.
    """
    return _join_stream(simple_stream(prompt, temperature), endstring)
#https://scribe.rip/rahasak/build-rag-application-using-a-llm-running-on-local-computer-with-ollama-and-langchain-e6513853fda0
def init_index():
    """Rebuild the Chroma vector index from the MediaWiki XML dump.

    Loads ``wikilocation``, splits it into overlapping chunks, filters redundant
    chunks with ``EmbeddingsRedundantFilter``, and persists the result to
    ``index_dir`` in the "pvv-wiki" collection. The previous index is deleted
    only after the new documents have been prepared, so a missing or unreadable
    dump no longer wipes out a working index.
    """
    # Load data from MediaWiki dump
    # TODO: add chat history to the documents
    documents = MWDumpLoader(wikilocation).load()
    # Split text into chunks of up to 1000 with an overlap of 200
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    documents = text_splitter.split_documents(documents)
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model,
        multi_process=True,
        # encode_kwargs={"normalize_embeddings": True},
    )
    # Apply the redundant filter
    redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
    documents = redundant_filter.transform_documents(documents)
    # Remove the current index only now that the replacement is ready to be written.
    if os.path.exists(index_dir):
        shutil.rmtree(index_dir)
    vectordb = Chroma.from_documents(
        documents=documents,
        embedding=embeddings,
        persist_directory=index_dir,
        collection_name="pvv-wiki",
    )
    vectordb.persist()
# JSON fields the model is asked to produce; the emotion/action descriptions
# list the allowed enum values so the model can pick from them.
response_schemas = [
    ResponseSchema(
        name="response",
        description="reply to the user's question or statement.",
    ),
    ResponseSchema(
        name="emotion",
        description=(
            "emotion expressed in the response, selected from a set of possible options "
            f"{[e.value for e in Emotion]}"
        ),
        type="Emotion",
    ),
    ResponseSchema(
        name="actions",
        description=(
            "List of actions to take at random in response to the user's question or statement from the set "
            f"{[a.value for a in Action]}"
        ),
        type="List[Action]",
    ),
]
# Parser for the model's structured answer, plus the matching instructions
# that are injected into the system prompt.
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
# System prompt for the retrieval chain. Placeholders: {date}, {context}
# (retrieved wiki chunks), {image_description} and {format_instructions}
# (JSON output instructions from the structured output parser).
general_system_template = r"""
Given a specific context, please give a short answer to the question, using relevant context to try to find a possible outcome. If the data does not help, express uncertainty in the final answer.
The current date is {date}
----
{context}
----
You may refer to what you can see in front of you in the description below. Any reference to the camera or image should be interpreted as "you", "your eyes" or what you can see:
{image_description}
----
Do not refer to yourself as an AI.
Avoid expressions in the response.
You are a cute anime character named pvv chan; you like programming, Linux, open source and board games.
Without referring to yourself, reply to the human talking to you.
{format_instructions}
"""
# The user's message, fenced so it is clearly delimited from the instructions.
general_user_template = "Question:```{question}```"
messages = [
    SystemMessagePromptTemplate.from_template(general_system_template),
    HumanMessagePromptTemplate.from_template(general_user_template),
]
qa_prompt = ChatPromptTemplate.from_messages(messages)
def init_chat():
    """Open the persisted vector index and build the global ``conversation`` chain.

    The chain retrieves the 2 most relevant indexed chunks for each question and
    answers using ``qa_prompt``. chat() calls the resulting ``conversation``,
    so this must run before it.
    """
    global conversation
    # load index from local directory
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model,
        multi_process=True,
        # encode_kwargs={"normalize_embeddings": True},
    )
    # Must name the same collection init_index() writes ("pvv-wiki"); without
    # collection_name Chroma opens its default collection, which is never filled.
    vectordb = Chroma(
        persist_directory=index_dir,
        embedding_function=embeddings,
        collection_name="pvv-wiki",
    )
    llm = Ollama(
        model=text_model,
        base_url=ollamaUrl,
        verbose=True,
        # llama3 end/header tokens must be given as explicit stop sequences for now.
        stop=["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>", "<|reserved_special_token|>"],
    )
    # create conversation
    conversation = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vectordb.as_retriever(search_kwargs={"k": 2}),  # amount of documents to use for the response
        return_source_documents=True,
        verbose=True,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
    )
if __name__ == "__main__":
    # print(simple(prompt="What is the meaning of life. (answer short)"))
    # Rebuilding re-embeds the whole wiki dump, so only build the index when
    # none exists yet; otherwise reuse the persisted one.
    if os.path.exists(index_dir):
        print("using existing index")
    else:
        print("initialising index")
        init_index()
    print("initialising chat")
    init_chat()
    print("chatting")
    print(chat(question="Hello, how are you today? What is our dns server named?"))