【Streamlit/LangChain】1度に複数のLLMに質問してみる(Amazon Bedrock, Google Gemini, Azure OpenAI)

記事タイトルとURLをコピーする

こんにちは。AWS CLIが好きな福島です。

はじめに

突然ですが、さまざまな種類の大規模言語モデル(LLM)が存在することを考えると、複数のLLMに同時に質問してみたいと思ったことはないでしょうか。

そこで今回は、StreamlitとLangChainを用いて、一度に複数のLLMへの質問を実現する方法をご紹介いたします。

イメージ

動画の通り、上部にあるセレクトボックスに質問したいLLMを複数選択し、 下にあるボックスに質問することで1度に複数のLLMに質問することができます。

※gemeni-proだけストリーミングする方法が不明だったため、ストリーミングできていません...

サンプルコード

コードはフロントエンドとバックエンドで2つのファイルにまとめてます。 色々と気になる点あるかと思いますが、ご容赦ください。

フロントエンド(frontend.py)

import streamlit as st
import uuid
import os
from backend_generate_answer import ChatApplication
import time

ss = st.session_state

CONVERSATION_HISTORY_DB = os.environ.get("CONVERSATION_HISTORY_DB")
SELECT_MODELS = [
    "anthropic.claude-v2:1",
    "gpt-35-turbo",
    "gemini-pro",
    # "ai21.j2-ultra-v1", // ストリーミング出力にサポートしていないためコメントアウト
    # "meta.llama2-13b-chat-v1"
    "meta.llama2-70b-chat-v1",
    # "cohere.command-light-text-v14",
    "cohere.command-text-v14",
    # "amazon.titan-text-lite-v1"
    "amazon.titan-text-express-v1",
]


def get_or_init_chat_msgs(model):
    """
    Streamlitのセッションにchat_msgsが存在しない場合、chat_msgsをリストとして定義する。
    既に存在する場合は、chat_msgsの値を返す。
    """
    model_chat_msgs = f'{model}_chat_msgs'
    if model_chat_msgs not in ss:
        ss[model_chat_msgs] = []

    return ss[model_chat_msgs]


def display_chat_msgs(chat_msgs, display_chat_msgs_ph):
    """会話履歴を表示する関数"""
    for entry in chat_msgs:
        if entry["user"] == "You":
            display_chat_msgs_ph.chat_message("user").write(entry["message"])
        elif entry["user"] == "Bot":
            display_chat_msgs_ph.chat_message(
                "assistant").write(entry["message"])
        else:
            return None


def get_or_init_session_id(model):
    """
    Streamlitのセッションにsession_idが存在しない場合、sessiond_idにUUIDを割り当てる。
    """
    model_session_id = f'{model}_session_id'
    if model_session_id not in ss:
        ss[model_session_id] = f'{model}_{str(uuid.uuid4())}'

    return ss[model_session_id]


def get_or_init_app(model):
    """
    Streamlitのセッションにappが存在しない場合、ChatApplicationをインスタンス化してappに割り当てる。
    既に存在する場合は、appの値を返す。
    """

    config = {
        "session_id": get_or_init_session_id(model),
        "conversation_history_db": CONVERSATION_HISTORY_DB,
        "user_messages_ph": ss['user_messages_ph'],
        "bot_messages_ph": ss['bot_messages_ph']
    }

    model_app = f'{model}_app'
    if model_app not in ss:
        ss[model_app] = ChatApplication(config)

    return ss[model_app]


def record_chat_msgs(model, user_messages, bot_messages):
    """会話履歴を保存する関数"""
    chat_msgs = get_or_init_chat_msgs(model)
    chat_msgs.append(
        {
            "user": "You",
            "message": user_messages
        }
    )

    chat_msgs.append(
        {
            "user": "Bot",
            "message": bot_messages
        }
    )

    return chat_msgs


def chatbot_handler():
    """ユーザーがメッセージを送信した際に動くハンドラー関数"""

    user_messages = ss.user_messages

    for model in ss['select_models']:

        app = get_or_init_app(model)
        start_time = time.time()
        answer = app.generate_answer(user_messages, model)['answer']
        end_time = time.time()
        elapsed_time = end_time - start_time

        bot_messages = f'{answer}\n\n生成時間: {round(elapsed_time,1)}秒'
        ss.chat_msgs = record_chat_msgs(
            model, user_messages, bot_messages
        )


def main():
    """メイン関数"""

    st.set_page_config(
        page_title="Multiple LLM Chatbot",
        layout="wide",
        initial_sidebar_state="auto",
    )

    st.title('Multiple LLM Chatbot')

    ss['select_models'] = st.multiselect(
        label="モデルを選択してください",
        options=SELECT_MODELS,
        default=SELECT_MODELS[0],
    )

    if ss['select_models']:
        columns = st.columns(len(ss['select_models']))

        for i, col in enumerate(columns):
            with col:
                model = ss['select_models'][i]
                st.text(model)
                display_chat_msgs(get_or_init_chat_msgs(model), st.container())

    ss['user_messages_ph'] = st.empty()
    ss['bot_messages_ph'] = st.empty()

    st.chat_input(
        placeholder="メッセージを入力してください",
        key="user_messages",
        on_submit=chatbot_handler,
    )


if __name__ == "__main__":
    main()

バックエンド(backend.py)

import boto3
from langchain.chains import ConversationChain
from langchain.llms import Bedrock
from langchain.callbacks.base import BaseCallbackHandler
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.memory import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import DynamoDBChatMessageHistory
from langchain.chat_models import AzureChatOpenAI


class StreamHandler(BaseCallbackHandler):
    """回答をストリーミングするクラス"""

    def __init__(self, user_messages_ph, bot_messages_ph, user_messages, model_id):
        self.user_messages_ph = user_messages_ph
        self.bot_messages_ph = bot_messages_ph
        self.user_messages = user_messages
        self.generate_answer = ""
        self.model_id = model_id

    def on_llm_start(self, serialized, prompts, **kwargs):
        self.user_messages_ph.chat_message("user").write(
            f'To: {self.model_id} \n\n {self.user_messages} '
        )

    def on_llm_new_token(self, token: str, **kwargs):
        self.generate_answer += token
        self.bot_messages_ph.chat_message(
            "assistant").write(self.generate_answer)


class ConversationHistoryMemory():
    """会話を保存するメモリを初期化するクラス"""

    def __init__(self, conversation_history_db, session_id):
        self.conversation_history_db = conversation_history_db
        self.session_id = session_id

    def init_memory(self, memory_key, input_key, output_key=None):
        message_history = DynamoDBChatMessageHistory(
            table_name=self.conversation_history_db,
            session_id=self.session_id
        )

        memory = ConversationBufferWindowMemory(
            k=5,
            memory_key=memory_key,
            chat_memory=message_history,
            input_key=input_key,
            output_key=output_key,
            return_messages=True,
        )

        return memory


class GenerateAnswerLLM():
    """回答を生成するLLMを初期化するクラス"""

    def __init__(self, user_messages_ph, bot_messages_ph, user_messages, model_id):
        self.model_id = model_id
        self.callback_handler = StreamHandler(
            user_messages_ph,
            bot_messages_ph,
            user_messages,
            self.model_id
        )

    def init_generate_llm(self):

        bedrock_runtime = boto3.client(service_name="bedrock-runtime")

        generate_llm = Bedrock(
            model_id=self.model_id,
            client=bedrock_runtime,
            # model_kwargs={"max_tokens_to_sample": 4096},  // モデルごとにオプションが異なるためコメントアウト
            streaming=True,
            callbacks=[self.callback_handler]
        )

        return generate_llm

    def init_generate_google_llm(self):

        generate_llm = ChatGoogleGenerativeAI(
            model=self.model_id,
            streaming=True,
            callbacks=[self.callback_handler]
        )

        return generate_llm

    def init_generate_azure_llm(self):

        generate_llm = AzureChatOpenAI(
            azure_deployment=self.model_id,
            openai_api_version="2023-05-15",
            streaming=True,
            callbacks=[self.callback_handler]
        )

        return generate_llm


class ChatApplication():

    def __init__(self, config):
        self.user_messages_ph = config['user_messages_ph']
        self.bot_messages_ph = config['bot_messages_ph']

        self.memory_instance = ConversationHistoryMemory(
            config['conversation_history_db'],
            config['session_id']
        )
        self.memory = self.memory_instance.init_memory(
            memory_key='history',
            input_key='input'
        )

    def generate_answer(self, user_messages, model_id):
        generate_llm_instance = GenerateAnswerLLM(
            self.user_messages_ph,
            self.bot_messages_ph,
            user_messages,
            model_id
        )

        if model_id == 'gemini-pro':
            generate_llm = generate_llm_instance.init_generate_google_llm()
        elif model_id == 'gpt-35-turbo':
            generate_llm = generate_llm_instance.init_generate_azure_llm()
        else:
            generate_llm = generate_llm_instance.init_generate_llm()

        generate_llm_chain = ConversationChain(
            llm=generate_llm,
            memory=self.memory
        )

        response = generate_llm_chain.predict(input=user_messages)

        return {"answer": response}

実行方法

上記二つのファイルを同一のフォルダに配置します。

次に会話履歴を保存するようにSessionIdというキーを持ったDynamoDBを構築し、環境変数に定義します。

export CONVERSATION_HISTORY_DB=【DynamoDBのテーブル名】

また、GoogleやOpenAIに関する情報を以下のように環境変数に定義します。

export GOOGLE_API_KEY=【GoogleのAPIキー】
export OPENAI_API_KEY=【OpenAIのAPIキー】
export AZURE_OPENAI_ENDPOINT=【Azure OpenAIのエンドポイント】

その後、以下のコマンドを実行すれば、ブラウザが立ち上がるかと思います。

streamlit run frondend.py

ポイント

ストリーミング出力について

ストリーミング出力のポイントとなるコードは以下の通りです。

まずは、フロントエンド側でメッセージを出力する要素(user_messages_ph,bot_messages_ph)を定義しておきます。

## フロントエンド(frontend.py)
    ss['user_messages_ph'] = st.empty()
    ss['bot_messages_ph'] = st.empty()

次にその要素(user_messages_ph,bot_messages_ph)をChatApplicationクラスをインスタンス化する際に渡します。

## フロントエンド(frontend.py)
    config = {
        "session_id": get_or_init_session_id(model),
        "conversation_history_db": CONVERSATION_HISTORY_DB,
        "user_messages_ph": ss['user_messages_ph'],
        "bot_messages_ph": ss['bot_messages_ph']
    }

    model_app = f'{model}_app'
    if model_app not in ss:
        ss[model_app] = ChatApplication(config)

最終的にフロントエンドから渡った要素(user_messages_ph,bot_messages_ph)は、StreamHandlerクラスに渡ります。

StreamHandlerクラスでは、LLMが回答を開始するタイミングでユーザーからの質問をuser_messages_phに表示し、LLMがストリーミングで回答を生成している際に、bot_messages_phにメッセージをストリーミング出力しています。

## バックエンド(backend.py)
class StreamHandler(BaseCallbackHandler):
    """回答をストリーミングするクラス"""

    def __init__(self, user_messages_ph, bot_messages_ph, user_messages, model_id):
        self.user_messages_ph = user_messages_ph
        self.bot_messages_ph = bot_messages_ph
        self.user_messages = user_messages
        self.generate_answer = ""
        self.model_id = model_id

    def on_llm_start(self, serialized, prompts, **kwargs):
        self.user_messages_ph.chat_message("user").write(
            f'To: {self.model_id} \n\n {self.user_messages} '
        )

    def on_llm_new_token(self, token: str, **kwargs):
        self.generate_answer += token
        self.bot_messages_ph.chat_message(
            "assistant").write(self.generate_answer)

※補足 StreamlitCallbackHandlerというクラスが用意されているため、Streamlitを使ってストリーミング出力する場合は、こちらを使う方が簡単に実装できるかもしれません。

https://python.langchain.com/docs/integrations/callbacks/streamlit

会話履歴の表示

まず、質問と生成された回答は、Streamlitのセッション保持機能を使いつつ、リストの辞書型のchat_msgsに保存しています。

## フロントエンド(frontend.py)
def record_chat_msgs(model, user_messages, bot_messages):
    """会話履歴を保存する関数"""
    chat_msgs = get_or_init_chat_msgs(model)
    chat_msgs.append(
        {
            "user": "You",
            "message": user_messages
        }
    )

    chat_msgs.append(
        {
            "user": "Bot",
            "message": bot_messages
        }
    )

    return chat_msgs

そして以下の関数により、保存されているメッセージを画面に描画されます。

## フロントエンド(frontend.py)

def display_chat_msgs(chat_msgs, display_chat_msgs_ph):
    """会話履歴を表示する関数"""
    for entry in chat_msgs:
        if entry["user"] == "You":
            display_chat_msgs_ph.chat_message("user").write(entry["message"])
        elif entry["user"] == "Bot":
            display_chat_msgs_ph.chat_message(
                "assistant").write(entry["message"])
        else:
            return None

選択したモデルに応じた動的なカラムの用意

以下のコードによって選択されたLLMの数に応じて、st.columnsを用意します。 st.columns(LLM)ごとの出力は、for文内のコードによって実現されています。

## フロントエンド(frontend.py)

    if ss['select_models']:
        columns = st.columns(len(ss['select_models']))

        for i, col in enumerate(columns):
            with col:
                model = ss['select_models'][i]
                st.text(model)
                display_chat_msgs(get_or_init_chat_msgs(model), st.container())

補足

モデルを追加したい場合

SELECT_MODELSの設定を変えることで自由に変更ができます。

SELECT_MODELS = [
    "anthropic.claude-v2:1",
    "gpt-35-turbo",
    "gemini-pro",
    # "ai21.j2-ultra-v1",
    # "meta.llama2-13b-chat-v1"
    "meta.llama2-70b-chat-v1",
    # "cohere.command-light-text-v14",
    "cohere.command-text-v14",
    # "amazon.titan-text-lite-v1"
    "amazon.titan-text-express-v1",
]

モデルのデフォルトの数を変更したい場合

★の箇所を変更することで変更できます。

## フロントエンド(frontend.py)
    ss['select_models'] = st.multiselect(
        label="モデルを選択してください",
        options=SELECT_MODELS,
        default=SELECT_MODELS[0], ★
    )

終わりに

今回は興味本位でStreamlitとLangChainを活用して、1度に複数のLLMに質問してみました。 1度に複数のLLMに質問してみたいと思っている方の参考になれば幸いです。

福島 和弥 (記事一覧)

2019/10 入社

AWS CLIが好きです。