こんにちは。AWS CLIが好きな福島です。
はじめに
突然ですが、さまざまな種類の大規模言語モデル(LLM)が存在することを考えると、複数のLLMに同時に質問してみたいと思ったことはないでしょうか。
そこで今回は、StreamlitとLangChainを用いて、一度に複数のLLMへの質問を実現する方法をご紹介いたします。
イメージ
動画の通り、上部にあるセレクトボックスに質問したいLLMを複数選択し、 下にあるボックスに質問することで1度に複数のLLMに質問することができます。
※gemeni-proだけストリーミングする方法が不明だったため、ストリーミングできていません...
サンプルコード
コードはフロントエンドとバックエンドで2つのファイルにまとめてます。 色々と気になる点あるかと思いますが、ご容赦ください。
フロントエンド(frontend.py)
import streamlit as st import uuid import os from backend_generate_answer import ChatApplication import time ss = st.session_state CONVERSATION_HISTORY_DB = os.environ.get("CONVERSATION_HISTORY_DB") SELECT_MODELS = [ "anthropic.claude-v2:1", "gpt-35-turbo", "gemini-pro", # "ai21.j2-ultra-v1", // ストリーミング出力にサポートしていないためコメントアウト # "meta.llama2-13b-chat-v1" "meta.llama2-70b-chat-v1", # "cohere.command-light-text-v14", "cohere.command-text-v14", # "amazon.titan-text-lite-v1" "amazon.titan-text-express-v1", ] def get_or_init_chat_msgs(model): """ Streamlitのセッションにchat_msgsが存在しない場合、chat_msgsをリストとして定義する。 既に存在する場合は、chat_msgsの値を返す。 """ model_chat_msgs = f'{model}_chat_msgs' if model_chat_msgs not in ss: ss[model_chat_msgs] = [] return ss[model_chat_msgs] def display_chat_msgs(chat_msgs, display_chat_msgs_ph): """会話履歴を表示する関数""" for entry in chat_msgs: if entry["user"] == "You": display_chat_msgs_ph.chat_message("user").write(entry["message"]) elif entry["user"] == "Bot": display_chat_msgs_ph.chat_message( "assistant").write(entry["message"]) else: return None def get_or_init_session_id(model): """ Streamlitのセッションにsession_idが存在しない場合、sessiond_idにUUIDを割り当てる。 """ model_session_id = f'{model}_session_id' if model_session_id not in ss: ss[model_session_id] = f'{model}_{str(uuid.uuid4())}' return ss[model_session_id] def get_or_init_app(model): """ Streamlitのセッションにappが存在しない場合、ChatApplicationをインスタンス化してappに割り当てる。 既に存在する場合は、appの値を返す。 """ config = { "session_id": get_or_init_session_id(model), "conversation_history_db": CONVERSATION_HISTORY_DB, "user_messages_ph": ss['user_messages_ph'], "bot_messages_ph": ss['bot_messages_ph'] } model_app = f'{model}_app' if model_app not in ss: ss[model_app] = ChatApplication(config) return ss[model_app] def record_chat_msgs(model, user_messages, bot_messages): """会話履歴を保存する関数""" chat_msgs = get_or_init_chat_msgs(model) chat_msgs.append( { "user": "You", "message": user_messages } ) chat_msgs.append( { "user": "Bot", "message": bot_messages } ) return chat_msgs def chatbot_handler(): """ユーザーがメッセージを送信した際に動くハンドラー関数""" user_messages = ss.user_messages for model in ss['select_models']: app = get_or_init_app(model) start_time = time.time() answer = app.generate_answer(user_messages, model)['answer'] end_time = time.time() elapsed_time = end_time - start_time bot_messages = f'{answer}\n\n生成時間: {round(elapsed_time,1)}秒' ss.chat_msgs = record_chat_msgs( model, user_messages, bot_messages ) def main(): """メイン関数""" st.set_page_config( page_title="Multiple LLM Chatbot", layout="wide", initial_sidebar_state="auto", ) st.title('Multiple LLM Chatbot') ss['select_models'] = st.multiselect( label="モデルを選択してください", options=SELECT_MODELS, default=SELECT_MODELS[0], ) if ss['select_models']: columns = st.columns(len(ss['select_models'])) for i, col in enumerate(columns): with col: model = ss['select_models'][i] st.text(model) display_chat_msgs(get_or_init_chat_msgs(model), st.container()) ss['user_messages_ph'] = st.empty() ss['bot_messages_ph'] = st.empty() st.chat_input( placeholder="メッセージを入力してください", key="user_messages", on_submit=chatbot_handler, ) if __name__ == "__main__": main()
バックエンド(backend.py)
import boto3 from langchain.chains import ConversationChain from langchain.llms import Bedrock from langchain.callbacks.base import BaseCallbackHandler from langchain_google_genai import ChatGoogleGenerativeAI from langchain.memory import ConversationBufferWindowMemory from langchain.memory.chat_message_histories import DynamoDBChatMessageHistory from langchain.chat_models import AzureChatOpenAI class StreamHandler(BaseCallbackHandler): """回答をストリーミングするクラス""" def __init__(self, user_messages_ph, bot_messages_ph, user_messages, model_id): self.user_messages_ph = user_messages_ph self.bot_messages_ph = bot_messages_ph self.user_messages = user_messages self.generate_answer = "" self.model_id = model_id def on_llm_start(self, serialized, prompts, **kwargs): self.user_messages_ph.chat_message("user").write( f'To: {self.model_id} \n\n {self.user_messages} ' ) def on_llm_new_token(self, token: str, **kwargs): self.generate_answer += token self.bot_messages_ph.chat_message( "assistant").write(self.generate_answer) class ConversationHistoryMemory(): """会話を保存するメモリを初期化するクラス""" def __init__(self, conversation_history_db, session_id): self.conversation_history_db = conversation_history_db self.session_id = session_id def init_memory(self, memory_key, input_key, output_key=None): message_history = DynamoDBChatMessageHistory( table_name=self.conversation_history_db, session_id=self.session_id ) memory = ConversationBufferWindowMemory( k=5, memory_key=memory_key, chat_memory=message_history, input_key=input_key, output_key=output_key, return_messages=True, ) return memory class GenerateAnswerLLM(): """回答を生成するLLMを初期化するクラス""" def __init__(self, user_messages_ph, bot_messages_ph, user_messages, model_id): self.model_id = model_id self.callback_handler = StreamHandler( user_messages_ph, bot_messages_ph, user_messages, self.model_id ) def init_generate_llm(self): bedrock_runtime = boto3.client(service_name="bedrock-runtime") generate_llm = Bedrock( model_id=self.model_id, client=bedrock_runtime, # model_kwargs={"max_tokens_to_sample": 4096}, // モデルごとにオプションが異なるためコメントアウト streaming=True, callbacks=[self.callback_handler] ) return generate_llm def init_generate_google_llm(self): generate_llm = ChatGoogleGenerativeAI( model=self.model_id, streaming=True, callbacks=[self.callback_handler] ) return generate_llm def init_generate_azure_llm(self): generate_llm = AzureChatOpenAI( azure_deployment=self.model_id, openai_api_version="2023-05-15", streaming=True, callbacks=[self.callback_handler] ) return generate_llm class ChatApplication(): def __init__(self, config): self.user_messages_ph = config['user_messages_ph'] self.bot_messages_ph = config['bot_messages_ph'] self.memory_instance = ConversationHistoryMemory( config['conversation_history_db'], config['session_id'] ) self.memory = self.memory_instance.init_memory( memory_key='history', input_key='input' ) def generate_answer(self, user_messages, model_id): generate_llm_instance = GenerateAnswerLLM( self.user_messages_ph, self.bot_messages_ph, user_messages, model_id ) if model_id == 'gemini-pro': generate_llm = generate_llm_instance.init_generate_google_llm() elif model_id == 'gpt-35-turbo': generate_llm = generate_llm_instance.init_generate_azure_llm() else: generate_llm = generate_llm_instance.init_generate_llm() generate_llm_chain = ConversationChain( llm=generate_llm, memory=self.memory ) response = generate_llm_chain.predict(input=user_messages) return {"answer": response}
実行方法
上記二つのファイルを同一のフォルダに配置します。
次に会話履歴を保存するようにSessionIdというキーを持ったDynamoDBを構築し、環境変数に定義します。
export CONVERSATION_HISTORY_DB=【DynamoDBのテーブル名】
また、GoogleやOpenAIに関する情報を以下のように環境変数に定義します。
export GOOGLE_API_KEY=【GoogleのAPIキー】 export OPENAI_API_KEY=【OpenAIのAPIキー】 export AZURE_OPENAI_ENDPOINT=【Azure OpenAIのエンドポイント】
その後、以下のコマンドを実行すれば、ブラウザが立ち上がるかと思います。
streamlit run frondend.py
ポイント
ストリーミング出力について
ストリーミング出力のポイントとなるコードは以下の通りです。
まずは、フロントエンド側でメッセージを出力する要素(user_messages_ph,bot_messages_ph)を定義しておきます。
## フロントエンド(frontend.py) ss['user_messages_ph'] = st.empty() ss['bot_messages_ph'] = st.empty()
次にその要素(user_messages_ph,bot_messages_ph)をChatApplicationクラスをインスタンス化する際に渡します。
## フロントエンド(frontend.py) config = { "session_id": get_or_init_session_id(model), "conversation_history_db": CONVERSATION_HISTORY_DB, "user_messages_ph": ss['user_messages_ph'], "bot_messages_ph": ss['bot_messages_ph'] } model_app = f'{model}_app' if model_app not in ss: ss[model_app] = ChatApplication(config)
最終的にフロントエンドから渡った要素(user_messages_ph,bot_messages_ph)は、StreamHandlerクラスに渡ります。
StreamHandlerクラスでは、LLMが回答を開始するタイミングでユーザーからの質問をuser_messages_phに表示し、LLMがストリーミングで回答を生成している際に、bot_messages_phにメッセージをストリーミング出力しています。
## バックエンド(backend.py) class StreamHandler(BaseCallbackHandler): """回答をストリーミングするクラス""" def __init__(self, user_messages_ph, bot_messages_ph, user_messages, model_id): self.user_messages_ph = user_messages_ph self.bot_messages_ph = bot_messages_ph self.user_messages = user_messages self.generate_answer = "" self.model_id = model_id def on_llm_start(self, serialized, prompts, **kwargs): self.user_messages_ph.chat_message("user").write( f'To: {self.model_id} \n\n {self.user_messages} ' ) def on_llm_new_token(self, token: str, **kwargs): self.generate_answer += token self.bot_messages_ph.chat_message( "assistant").write(self.generate_answer)
※補足 StreamlitCallbackHandlerというクラスが用意されているため、Streamlitを使ってストリーミング出力する場合は、こちらを使う方が簡単に実装できるかもしれません。
https://python.langchain.com/docs/integrations/callbacks/streamlit
会話履歴の表示
まず、質問と生成された回答は、Streamlitのセッション保持機能を使いつつ、リストの辞書型のchat_msgsに保存しています。
## フロントエンド(frontend.py) def record_chat_msgs(model, user_messages, bot_messages): """会話履歴を保存する関数""" chat_msgs = get_or_init_chat_msgs(model) chat_msgs.append( { "user": "You", "message": user_messages } ) chat_msgs.append( { "user": "Bot", "message": bot_messages } ) return chat_msgs
そして以下の関数により、保存されているメッセージを画面に描画されます。
## フロントエンド(frontend.py) def display_chat_msgs(chat_msgs, display_chat_msgs_ph): """会話履歴を表示する関数""" for entry in chat_msgs: if entry["user"] == "You": display_chat_msgs_ph.chat_message("user").write(entry["message"]) elif entry["user"] == "Bot": display_chat_msgs_ph.chat_message( "assistant").write(entry["message"]) else: return None
選択したモデルに応じた動的なカラムの用意
以下のコードによって選択されたLLMの数に応じて、st.columnsを用意します。 st.columns(LLM)ごとの出力は、for文内のコードによって実現されています。
## フロントエンド(frontend.py) if ss['select_models']: columns = st.columns(len(ss['select_models'])) for i, col in enumerate(columns): with col: model = ss['select_models'][i] st.text(model) display_chat_msgs(get_or_init_chat_msgs(model), st.container())
補足
モデルを追加したい場合
SELECT_MODELSの設定を変えることで自由に変更ができます。
SELECT_MODELS = [ "anthropic.claude-v2:1", "gpt-35-turbo", "gemini-pro", # "ai21.j2-ultra-v1", # "meta.llama2-13b-chat-v1" "meta.llama2-70b-chat-v1", # "cohere.command-light-text-v14", "cohere.command-text-v14", # "amazon.titan-text-lite-v1" "amazon.titan-text-express-v1", ]
モデルのデフォルトの数を変更したい場合
★の箇所を変更することで変更できます。
## フロントエンド(frontend.py) ss['select_models'] = st.multiselect( label="モデルを選択してください", options=SELECT_MODELS, default=SELECT_MODELS[0], ★ )
終わりに
今回は興味本位でStreamlitとLangChainを活用して、1度に複数のLLMに質問してみました。 1度に複数のLLMに質問してみたいと思っている方の参考になれば幸いです。