【OpenAI Embedding】精度をあげるための事前処理

ドキュメントの内容をもとにAIを使ってQAをしようとすると、
OpenAI の場合Embeddingを使用することになると思います。

ただ、このEmbeddingはただファイルをそのまま食わせてもなかなか精度が出ません。

どうしたら精度がでるかと試行錯誤した結果、何個か精度に関係しそうな事前処理を発見したのでご紹介したいと思います!

ひらべー
ひらべー

この記事はこんな人にオススメ!

・Embeddingをこれから試そうと思っている

・Embeddingを試したら精度が悪い

※ 本記事の精度に関する言及は、客観的な事実ではなく経験則からの主観的なものとなります。
ご了承の上ご参考までにご覧ください。

表形式のデータは1行1行で完結させる

先にどういうことか例示します。

まずそのままEmbeddingすると精度が悪めな例です。

| TABLE_NAME    | COLUMN_NAME    | EXPLAIN         |
| ------------- | -------------- | --------------- |
| SAMPLE_TABLE1 | SAMPLE_COLUMN1 | これはサンプルです。|
|               | SAMPLE_COLUMN2 | これはサンプルです。|
|               | SAMPLE_COLUMN3 | これはサンプルです。|
| SAMPLE_TABLE2 | SAMPLE_COLUMN1 | これはサンプルです。|

次に精度が比較的良い例です。

| TABLE_NAME    | COLUMN_NAME    | EXPLAIN         |
| ------------- | -------------- | --------------- |
| SAMPLE_TABLE1 | SAMPLE_COLUMN1 | これはサンプルです。|
| SAMPLE_TABLE1 | SAMPLE_COLUMN2 | これはサンプルです。|
| SAMPLE_TABLE1 | SAMPLE_COLUMN3 | これはサンプルです。|
| SAMPLE_TABLE2 | SAMPLE_COLUMN1 | これはサンプルです。|

前者は、TABLE_NAMEが共通なので記載を省いていますが、後者ではすべてのセルが埋まっています。
後者だとどの1行を切り取っても意味が通じるので精度が上がる傾向があります。

もしくは表を分割してしまうというのもありです。

# SAMPLE_TABLE1

| COLUMN_NAME    | EXPLAIN         |
| -------------- | --------------- |
| SAMPLE_COLUMN1 | これはサンプルです。|
| SAMPLE_COLUMN1 | これはサンプルです。|
| SAMPLE_COLUMN1 | これはサンプルです。|

# SAMPLE_TABLE2

| COLUMN_NAME    | EXPLAIN         |
| -------------- | --------------- |
| SAMPLE_COLUMN1 | これはサンプルです。|

区切り位置を制御する

Embeddingを行う際は、ドキュメントを全文そのまま学習させるのではなく、一定の文字数以内におさまるよう分割をします。
この分割の際、意味のあるまとまりで分割できるかが肝になります。

例えば、先ほどの表を100文字で分割するとこんな感じになります。

# SAMPLE_TABLE1

| COLUMN_NAME    | EXPLAIN         |
| -------------- | --------------- |
| SAMPLE_COLUMN1 | これはサンプルです。|
| SAMPLE_COLUMN1 | これはサンプルです。|
| SAMPLE_COLUMN1 | これはサンプルです。|

これでは分割されたドキュメント単独では意味をなさなくなり、精度が下がります。

なので、きちんと意味のあるブロックがまとまって分割されるように「ある程度の大きさのチャンクを設定すること」と「分割される位置を調整すること」が必要になります。

ドキュメントの分割にはLangChainが使えそうだったのですが、もう少し自分で動きをコントロールしたいと思い、
以下のような関数を実装してみました。

# 文字コードの自動判別の機能が便利なのでFile -> テキストの部分だけLangChainを使用
from langchain_community.document_loaders import TextLoader

    # separatorsで指定された文字列で順にテキストを分割する
    # chunk_sizeで指定されたサイズ以下になったらそこで分割を終了する
    def split_text_recursive(
            self,
            file: str,
            separate_regex_list: list[str],
            chunk_size: int = 2000,
    ) -> list[str]:
        split_text = [doc.page_content for doc in TextLoader(file, autodetect_encoding=True).load()]
        for separator in separate_regex_list:
            new_split_text = []
            for text_block in split_text:
                if len(text_block) > chunk_size:
                    # 例: \nで分割する場合
                    # 前: a\nb\nc
                    # 後: [a, \n, b, \n, c]
                    split_text = re.split(f'({separator})', text_block)
                    # 例:
                    # 前: [a, \n, b, \n, c]
                    # 後: [a, \nb, \nc]
                    new_split_text.append(split_text[0])
                    for i in range(2, len(split_text), 2):
                        new_split_text.append(split_text[i-1] + split_text[i])

                else:
                    new_split_text.append(text_block)
            # 細かすぎる分割は結合
            split_text = self.merge_text(new_split_text)

        return split_text

    def merge_text(self, text_list: list[str], chunk_size: int = 2000) -> list[str]:
        # chunk_sizeで指定されたサイズ以下だったら次の要素との結合を試みる
        # → 新しい配列に追加
        merged_text = ''
        merged_text_list = []
        for text in text_list:
            # 要素を結合してみて、指定サイズ以下だったら結合
            if len(merged_text) + len(text) < chunk_size:
                merged_text += text
            else:
                merged_text_list.append(merged_text)
                merged_text = text

        if merged_text:
            merged_text_list.append(merged_text)

        return merged_text_list

使い方のイメージはこんな感じです(SQLやPLSQLファイルを対象にした例)

        separators = [
            '\n\/\*\n',  # PLSQLのコメント
            '\n *CREATE ', '\n *create ',
            '\n *INSERT ', '\n *insert ',
            '\n *UPDATE ', '\n *update ',
            '\n *DELETE ', '\n *delete ',
            '\n *ALTER ', '\n *alter ',
            '\n *DROP ', '\n *drop ',
            '\n *SELECT ', '\n *select ',
            '\n *PROCEDURE ', '\n *procedure ',
            '\n *BEGIN ', '\n *begin ',
            '\n *EXCEPTION ', '\n *exception ',
        ]
        split_text = split_text_recursive(file=file, separate_regex_list=separators, chunk_size=2000)

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です