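An experiment in generating IRIS-compatible SQL from user prompts using the LangChain framework, IRIS Vector Search, and an LLM.
This article is based on this notebook. You can run it in a ready-to-use environment with this application on OpenExchange.
First, install the required libraries: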
!pip install --upgrade --quiet langchain langchain-openai langchain-iris pandas
Next, import the required modules and set up the environment.
import os
import datetime
import hashlib
from copy import deepcopy
from sqlalchemy import create_engine
import getpass
import pandas as pd
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.docstore.document import Document
from langchain_community.document_loaders import DataFrameLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain.globals import set_llm_cache
from langchain.cache import SQLiteCache
from langchain_iris import IRISVector
LLM calls are cached with SQLiteCache.
# Cache for LLM calls
set_llm_cache(SQLiteCache(database_path=".langchain.db"))
Set the IRIS database connection parameters:
# IRIS database connection parameters
os.environ["ISC_LOCAL_SQL_HOSTNAME"] = "localhost"
os.environ["ISC_LOCAL_SQL_PORT"] = "1972"
os.environ["ISC_LOCAL_SQL_NAMESPACE"] = "IRISAPP"
os.environ["ISC_LOCAL_SQL_USER"] = "_system"
os.environ["ISC_LOCAL_SQL_PWD"] = "SYS"
If the OpenAI API key is not already set in the environment, prompt the user for it.
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass.getpass()
Create the connection string for the IRIS database:
# IRIS database connection string
args = {
    'hostname': os.getenv("ISC_LOCAL_SQL_HOSTNAME"),
    'port': os.getenv("ISC_LOCAL_SQL_PORT"),
    'namespace': os.getenv("ISC_LOCAL_SQL_NAMESPACE"),
    'username': os.getenv("ISC_LOCAL_SQL_USER"),
    'password': os.getenv("ISC_LOCAL_SQL_PWD")
}
iris_conn_str = f"iris://{args['username']}:{args['password']}@{args['hostname']}:{args['port']}/{args['namespace']}"
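One caveat with building the URL by plain string interpolation: a username or password containing characters such as `@`, `:` or `/` would make the URL ambiguous. A minimal sketch of a safer variant, assuming the same `iris://` URL layout as above; `build_iris_url` is a hypothetical helper, not part of langchain-iris or SQLAlchemy:

```python
from urllib.parse import quote_plus


def build_iris_url(username, password, hostname, port, namespace):
    """Build an iris:// URL, percent-encoding the credentials.

    Hypothetical helper for illustration; the URL layout mirrors the
    f-string used in the article.
    """
    user = quote_plus(username)
    pwd = quote_plus(password)
    return f"iris://{user}:{pwd}@{hostname}:{port}/{namespace}"


# Credentials without special characters come out unchanged
print(build_iris_url("_system", "SYS", "localhost", "1972", "IRISAPP"))
# A password containing '@' and '/' no longer breaks the URL structure
print(build_iris_url("_system", "p@ss/word", "localhost", "1972", "IRISAPP"))
```

With the default credentials used in this article the result is identical to the f-string version, so the helper only matters once real passwords are involved.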
Establish a connection to the IRIS database:
# Connection to IRIS database
engine = create_engine(iris_conn_str)
cnx = engine.connect().connection
Prepare a dictionary to hold context information for the system prompt.
# Dict for context information for system prompt
context = {}
context["top_k"] = 3
To turn user input into SQL queries compatible with the IRIS database, we need an effective prompt for the language model. We start with an initial prompt that gives basic instructions for generating SQL queries. This template is derived from LangChain's default prompt for MSSQL and customized for the IRIS database.
# Basic prompt template with IRIS database SQL instructions
iris_sql_template = """
You are an InterSystems IRIS expert. Given an input question, first create a syntactically correct InterSystems IRIS query to run and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the TOP clause as per InterSystems IRIS. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in single quotes ('') to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CAST(CURRENT_DATE as date) function to get the current date, if the question involves "today".
Use double quotes to delimit columns identifiers.
Return just plain SQL; don't apply any kind of formatting.
"""
This basic prompt configures the language model (LLM) to act as a SQL expert with guidance specific to the IRIS database. Next, to avoid hallucinations, we provide an auxiliary prompt that carries information about the database schema.
# SQL template extension for including tables context information
tables_prompt_template = """
Only use the following tables:
{table_info}
"""
To improve the accuracy of the LLM's responses, we use a technique called few-shot prompting, which means presenting the LLM with a few examples.
# SQL template extension for including few shots
prompt_sql_few_shots_template = """
Below are a number of examples of questions and their corresponding SQL queries.

{examples_value}
"""
Define the template for a single few-shot example.
# Few shots prompt template
example_prompt_template = "User input: {input}\nSQL query: {query}"
example_prompt = PromptTemplate.from_template(example_prompt_template)
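To make the template concrete, here is a minimal sketch that renders it with plain `str.format`, used only as a stand-in for `PromptTemplate` so that the snippet runs without LangChain. The two question/query pairs are taken from the example list used later in this article.

```python
# Same template string as above; str.format stands in for PromptTemplate
example_prompt_template = "User input: {input}\nSQL query: {query}"

# Two question/query pairs from the article's example list
samples = [
    {"input": "Find the total number of incidents.",
     "query": "SELECT COUNT(*) FROM Aviation.Event"},
    {"input": "How many crew members were involved in incidents?",
     "query": "SELECT COUNT(*) FROM Aviation.Crew"},
]

# Each example becomes a two-line block; blocks are separated by a blank line
rendered = "\n\n".join(example_prompt_template.format(**s) for s in samples)
print(rendered)
```

The joined string has the same shape as the text that is later substituted for `{examples_value}`.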
Build the user prompt from the few-shot example template.
# User prompt template
user_prompt = "\n" + example_prompt.invoke({"input": "{input}", "query": ""}).to_string()
Finally, compose all the prompts into the final one.
# Complete prompt template
prompt = (
    ChatPromptTemplate.from_messages([("system", iris_sql_template)])
    + ChatPromptTemplate.from_messages([("system", tables_prompt_template)])
    + ChatPromptTemplate.from_messages([("system", prompt_sql_few_shots_template)])
    + ChatPromptTemplate.from_messages([("human", user_prompt)])
)
prompt
This prompt expects the variables examples_value, input, table_info, and top_k.
The structure of the prompt is as follows.
ChatPromptTemplate(
    input_variables=['examples_value', 'input', 'table_info', 'top_k'],
    messages=[
        SystemMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['top_k'],
                template=iris_sql_template
            )
        ),
        SystemMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['table_info'],
                template=tables_prompt_template
            )
        ),
        SystemMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['examples_value'],
                template=prompt_sql_few_shots_template
            )
        ),
        HumanMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['input'],
                template=user_prompt
            )
        )
    ]
)
To see how the prompt will be sent to the LLM, we can fill the required variables with placeholder values.
prompt_value = prompt.invoke({
    "top_k": "<top_k>",
    "table_info": "<table_info>",
    "examples_value": "<examples_value>",
    "input": "<input>"
})
print(prompt_value.to_string())
System:
You are an InterSystems IRIS expert. Given an input question, first create a syntactically correct InterSystems IRIS query to run and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most <top_k> results using the TOP clause as per InterSystems IRIS. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in single quotes ('') to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CAST(CURRENT_DATE as date) function to get the current date, if the question involves "today".
Use double quotes to delimit columns identifiers.
Return just plain SQL; don't apply any kind of formatting.
System:
Only use the following tables:
<table_info>
System:
Below are a number of examples of questions and their corresponding SQL queries.

<examples_value>
Human:
User input: <input>
SQL query:
The prompt is now ready to be sent to the LLM once the required variables are supplied. With that in place, let's move on to the next step.
To produce accurate SQL queries, the language model (LLM) needs detailed information about the database tables. Without it, the LLM may generate queries that look plausible but are wrong because of hallucination. The first step is therefore to write a function that retrieves table definitions from the IRIS database.
The following function queries INFORMATION_SCHEMA to get the table definitions for a given schema. If a specific table is given, it retrieves the definition of that table only; otherwise it retrieves the definitions of all tables in the schema.
def get_table_definitions_array(cnx, schema, table=None):
    cursor = cnx.cursor()

    # Base query to get columns information
    query = """
    SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT, PRIMARY_KEY, null EXTRA
    FROM INFORMATION_SCHEMA.COLUMNS
    WHERE TABLE_SCHEMA = %s
    """

    # Parameters for the query
    params = [schema]

    # Adding optional filters
    if table:
        query += " AND TABLE_NAME = %s"
        params.append(table)

    # Execute the query
    cursor.execute(query, params)

    # Fetch the results
    rows = cursor.fetchall()

    # Process the results to generate the table definition(s)
    table_definitions = {}
    for row in rows:
        table_schema, table_name, column_name, column_type, is_nullable, column_default, column_key, extra = row
        if table_name not in table_definitions:
            table_definitions[table_name] = []
        table_definitions[table_name].append({
            "column_name": column_name,
            "column_type": column_type,
            "is_nullable": is_nullable,
            "column_default": column_default,
            "column_key": column_key,
            "extra": extra
        })

    # Note: never populated here, so no PRIMARY KEY clause is emitted below
    primary_keys = {}

    # Build the output string
    result = []
    for table_name, columns in table_definitions.items():
        table_def = f"CREATE TABLE {schema}.{table_name} (\n"
        column_definitions = []
        for column in columns:
            column_def = f" {column['column_name']} {column['column_type']}"
            if column['is_nullable'] == "NO":
                column_def += " NOT NULL"
            if column['column_default'] is not None:
                column_def += f" DEFAULT {column['column_default']}"
            if column['extra']:
                column_def += f" {column['extra']}"
            column_definitions.append(column_def)
        if table_name in primary_keys:
            pk_def = f" PRIMARY KEY ({', '.join(primary_keys[table_name])})"
            column_definitions.append(pk_def)
        table_def += ",\n".join(column_definitions)
        table_def += "\n);"
        result.append(table_def)

    return result
This example uses the Aviation schema, available here.
# Retrieve table definitions for the Aviation schema
tables = get_table_definitions_array(cnx, "Aviation")
print(tables)
The function returns a CREATE TABLE statement for every table in the Aviation schema.
[
    'CREATE TABLE Aviation.Aircraft (\n Event bigint NOT NULL,\n ID varchar NOT NULL,\n AccidentExplosion varchar,\n AccidentFire varchar,\n AirFrameHours varchar,\n AirFrameHoursSince varchar,\n AirFrameHoursSinceLastInspection varchar,\n AircraftCategory varchar,\n AircraftCertMaxGrossWeight integer,\n AircraftHomeBuilt varchar,\n AircraftKey integer NOT NULL,\n AircraftManufacturer varchar,\n AircraftModel varchar,\n AircraftRegistrationClass varchar,\n AircraftSerialNo varchar,\n AircraftSeries varchar,\n Damage varchar,\n DepartureAirportId varchar,\n DepartureCity varchar,\n DepartureCountry varchar,\n DepartureSameAsEvent varchar,\n DepartureState varchar,\n DepartureTime integer,\n DepartureTimeZone varchar,\n DestinationAirportId varchar,\n DestinationCity varchar,\n DestinationCountry varchar,\n DestinationSameAsLocal varchar,\n DestinationState varchar,\n EngineCount integer,\n EvacuationOccurred varchar,\n EventId varchar NOT NULL,\n FlightMedical varchar,\n FlightMedicalType varchar,\n FlightPhase integer,\n FlightPlan varchar,\n FlightPlanActivated varchar,\n FlightSiteSeeing varchar,\n FlightType varchar,\n GearType varchar,\n LastInspectionDate timestamp,\n LastInspectionType varchar,\n Missing varchar,\n OperationDomestic varchar,\n OperationScheduled varchar,\n OperationType varchar,\n OperatorCertificate varchar,\n OperatorCertificateNum varchar,\n OperatorCode varchar,\n OperatorCountry varchar,\n OperatorIndividual varchar,\n OperatorName varchar,\n OperatorState varchar,\n Owner varchar,\n OwnerCertified varchar,\n OwnerCountry varchar,\n OwnerState varchar,\n RegistrationNumber varchar,\n ReportedToICAO varchar,\n SeatsCabinCrew integer,\n SeatsFlightCrew integer,\n SeatsPassengers integer,\n SeatsTotal integer,\n SecondPilot varchar,\n childsub bigint NOT NULL DEFAULT $i(^Aviation.EventC("Aircraft"))\n);',
    'CREATE TABLE Aviation.Crew (\n Aircraft varchar NOT NULL,\n ID varchar NOT NULL,\n Age integer,\n AircraftKey integer NOT NULL,\n Category varchar,\n CrewNumber integer NOT NULL,\n EventId varchar NOT NULL,\n Injury varchar,\n MedicalCertification varchar,\n MedicalCertificationDate timestamp,\n MedicalCertificationValid varchar,\n Seat varchar,\n SeatbeltUsed varchar,\n Sex varchar,\n ShoulderHarnessUsed varchar,\n ToxicologyTestPerformed varchar,\n childsub bigint NOT NULL DEFAULT $i(^Aviation.AircraftC("Crew"))\n);',
    'CREATE TABLE Aviation.Event (\n ID bigint NOT NULL DEFAULT $i(^Aviation.EventD),\n AirportDirection integer,\n AirportDistance varchar,\n AirportElevation integer,\n AirportLocation varchar,\n AirportName varchar,\n Altimeter varchar,\n EventDate timestamp,\n EventId varchar NOT NULL,\n EventTime integer,\n FAADistrictOffice varchar,\n InjuriesGroundFatal integer,\n InjuriesGroundMinor integer,\n InjuriesGroundSerious integer,\n InjuriesHighest varchar,\n InjuriesTotal integer,\n InjuriesTotalFatal integer,\n InjuriesTotalMinor integer,\n InjuriesTotalNone integer,\n InjuriesTotalSerious integer,\n InvestigatingAgency varchar,\n LightConditions varchar,\n LocationCity varchar,\n LocationCoordsLatitude double,\n LocationCoordsLongitude double,\n LocationCountry varchar,\n LocationSiteZipCode varchar,\n LocationState varchar,\n MidAir varchar,\n NTSBId varchar,\n NarrativeCause varchar,\n NarrativeFull varchar,\n NarrativeSummary varchar,\n OnGroundCollision varchar,\n SkyConditionCeiling varchar,\n SkyConditionCeilingHeight integer,\n SkyConditionNonCeiling varchar,\n SkyConditionNonCeilingHeight integer,\n TimeZone varchar,\n Type varchar,\n Visibility varchar,\n WeatherAirTemperature integer,\n WeatherPrecipitation varchar,\n WindDirection integer,\n WindDirectionIndicator varchar,\n WindGust integer,\n WindGustIndicator varchar,\n WindVelocity integer,\n WindVelocityIndicator varchar\n);'
]
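With these table definitions in hand, we can move on to the next step: integrating them into the prompt for the LLM. This gives the LLM accurate and comprehensive information about the database schema when it generates SQL queries.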
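The string-building half of `get_table_definitions_array` can be exercised without a database connection. The sketch below is a hypothetical refactoring for illustration, not code from the original notebook: `build_create_table` takes column dictionaries in the same shape the function assembles from INFORMATION_SCHEMA and renders one CREATE TABLE string. The sample rows are hand-written.

```python
def build_create_table(schema, table_name, columns):
    """Render one CREATE TABLE string from column dictionaries.

    Mirrors the string-building loop of get_table_definitions_array;
    this standalone function is a hypothetical refactoring for testing.
    """
    column_definitions = []
    for column in columns:
        column_def = f" {column['column_name']} {column['column_type']}"
        if column["is_nullable"] == "NO":
            column_def += " NOT NULL"
        if column["column_default"] is not None:
            column_def += f" DEFAULT {column['column_default']}"
        column_definitions.append(column_def)
    body = ",\n".join(column_definitions)
    return f"CREATE TABLE {schema}.{table_name} (\n{body}\n);"


# Hand-written rows in the shape built from INFORMATION_SCHEMA.COLUMNS
sample_columns = [
    {"column_name": "ID", "column_type": "bigint",
     "is_nullable": "NO", "column_default": None},
    {"column_name": "EventDate", "column_type": "timestamp",
     "is_nullable": "YES", "column_default": None},
]
ddl = build_create_table("Aviation", "Event", sample_columns)
print(ddl)
```

Separating the rendering from the query in this way makes it easy to check the generated text that the LLM will eventually see.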
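When working with databases, especially large ones, sending the data definition language (DDL) for every table in the prompt may not be practical. The approach can work for small databases, but real-world databases often contain hundreds or thousands of tables, and processing all of them is inefficient.
Moreover, the language model is unlikely to need every table in the database in order to generate a SQL query effectively. To address this, we can use semantic search to select only the tables most relevant to the user's query.
This is done with semantic search through IRIS Vector Search. Note that the method works best when SQL element identifiers (tables, fields, keys, and so on) have meaningful names. If the identifiers are arbitrary codes, consider using a data dictionary instead.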
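As a rough illustration of the data dictionary idea: when identifiers are opaque, a human-readable description can be attached to each table definition before it is embedded, so that semantic search has meaningful text to match against. This is only a sketch of one possible approach; the table name, column codes, descriptions, and the `annotate_definition` helper below are entirely hypothetical.

```python
# Hypothetical data dictionary: identifier -> human-readable description
data_dictionary = {
    "T001": "Aircraft accident events",
    "T001.C01": "Date of the event",
    "T001.C02": "City where the event occurred",
}


def annotate_definition(table_name, ddl, dictionary):
    """Prefix a DDL string with SQL comments taken from the dictionary."""
    lines = []
    if table_name in dictionary:
        lines.append(f"-- {table_name}: {dictionary[table_name]}")
    prefix = table_name + "."
    for key, description in dictionary.items():
        if key.startswith(prefix):
            lines.append(f"-- {key[len(prefix):]}: {description}")
    lines.append(ddl)
    return "\n".join(lines)


ddl = "CREATE TABLE App.T001 (\n C01 timestamp,\n C02 varchar\n);"
annotated = annotate_definition("T001", ddl, data_dictionary)
print(annotated)
```

The annotated text, rather than the bare DDL, would then be the content that gets split and embedded.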
First, extract the table definitions into a pandas DataFrame.
# Retrieve table definitions into a pandas DataFrame
table_def = get_table_definitions_array(cnx=cnx, schema='Aviation')
table_df = pd.DataFrame(data=table_def, columns=["col_def"])
table_df["id"] = table_df.index + 1
table_df
The DataFrame (table_df) will look something like this:
|   | col_def | id |
|---|---|---|
| 0 | CREATE TABLE Aviation.Aircraft (\n Event bigi... | 1 |
| 1 | CREATE TABLE Aviation.Crew (\n Aircraft varch... | 2 |
| 2 | CREATE TABLE Aviation.Event (\n ID bigint NOT... | 3 |
Next, split the table definitions into LangChain Documents. This step is crucial for handling large chunks of text and extracting text embeddings:
loader = DataFrameLoader(table_df, page_content_column="col_def")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n")
tables_docs = text_splitter.split_documents(documents)
tables_docs
The resulting tables_docs list contains split documents with metadata, like so:
[
    Document(metadata={'id': 1}, page_content='CREATE TABLE Aviation.Aircraft (\n Event bigint NOT NULL,\n ID varchar NOT NULL,\n ...'),
    Document(metadata={'id': 2}, page_content='CREATE TABLE Aviation.Crew (\n Aircraft varchar NOT NULL,\n ID varchar NOT NULL,\n ...'),
    Document(metadata={'id': 3}, page_content='CREATE TABLE Aviation.Event (\n ID bigint NOT NULL DEFAULT $i(^Aviation.EventD),\n ...')
]
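The effect of the splitter settings can be approximated in a few lines of plain Python. This is a simplified stand-in for illustration, not LangChain's `CharacterTextSplitter` implementation (it ignores `chunk_overlap`, for instance): it greedily packs newline-separated lines into chunks of at most `chunk_size` characters and copies the table's `id` onto every chunk, which is how a chunk that no longer contains the CREATE TABLE line can still be traced back to its table. The `Demo.Wide` table is synthetic.

```python
def split_definition(text, table_id, chunk_size=400, separator="\n"):
    """Greedily pack separator-delimited pieces into chunks <= chunk_size.

    Simplified illustration only; overlap between chunks is not modeled.
    """
    chunks, current = [], ""
    for piece in text.split(separator):
        candidate = piece if not current else current + separator + piece
        if len(candidate) > chunk_size and current:
            chunks.append(current)
            current = piece
        else:
            current = candidate
    if current:
        chunks.append(current)
    # Every chunk keeps the id of the table it came from
    return [{"id": table_id, "page_content": c} for c in chunks]


# A synthetic 30-column table definition, long enough to need several chunks
ddl = "CREATE TABLE Demo.Wide (\n" + ",\n".join(
    f" Column{i} varchar" for i in range(30)) + "\n);"
docs = split_definition(ddl, table_id=1, chunk_size=200)
print(len(docs), [len(d["page_content"]) for d in docs])
```

Only the first chunk carries the table name; the others rely on the `id` metadata, which is why the later lookup goes through `metadata["id"]` rather than the chunk text.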
Now, use the IRISVector class from langchain-iris to extract embedding vectors and store them:
tables_vector_store = IRISVector.from_documents(
    embedding=OpenAIEmbeddings(),
    documents=tables_docs,
    connection_string=iris_conn_str,
    collection_name="sql_tables",
    pre_delete_collection=True
)
Note: The pre_delete_collection flag is set to True for demonstration purposes to ensure a fresh collection in each test run. In a production environment, this flag should generally be set to False.
With the table embeddings stored, you can now query for relevant tables based on user input:
input_query = "List the first 2 manufacturers"
relevant_tables_docs = tables_vector_store.similarity_search(input_query, k=3)
relevant_tables_docs
For example, querying for manufacturers might return:
[
    Document(metadata={'id': 1}, page_content='GearType varchar,\n LastInspectionDate timestamp,\n ...'),
    Document(metadata={'id': 1}, page_content='AircraftModel varchar,\n AircraftRegistrationClass varchar,\n ...'),
    Document(metadata={'id': 3}, page_content='LocationSiteZipCode varchar,\n LocationState varchar,\n ...')
]
From the metadata, you can see that the two top-ranked chunks belong to table ID 1 (Aviation.Aircraft), the table that is relevant to this query; the third chunk comes from table ID 3 (Aviation.Event).
While this approach is generally effective, it may not always be perfect. For instance, querying for crash sites might also return less relevant tables:
input_query = "List the top 10 most crash sites"
relevant_tables_docs = tables_vector_store.similarity_search(input_query, k=3)
relevant_tables_docs
Results might include:
[
    Document(metadata={'id': 3}, page_content='LocationSiteZipCode varchar,\n LocationState varchar,\n ...'),
    Document(metadata={'id': 3}, page_content='InjuriesGroundSerious integer,\n InjuriesHighest varchar,\n ...'),
    Document(metadata={'id': 1}, page_content='CREATE TABLE Aviation.Aircraft (\n Event bigint NOT NULL,\n ID varchar NOT NULL,\n ...')
]
Despite retrieving the correct Aviation.Event table twice, the Aviation.Aircraft table may also appear, which could be improved with additional filtering or thresholding. This is beyond the scope of this example and will be left for future implementations.
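One possible shape for such filtering, shown as a minimal sketch on hand-made data: keep only hits whose distance is below a cutoff and collapse them to distinct table IDs. The distances and the 0.25 cutoff are invented for illustration. In practice they would come from a scored search (LangChain vector stores commonly expose `similarity_search_with_score` for this, though whether a lower or higher score means "closer" depends on the store and its distance strategy, which is an assumption to verify).

```python
def filter_table_ids(scored_hits, max_distance):
    """Return distinct table ids whose distance is <= max_distance.

    scored_hits: iterable of (table_id, distance); smaller means closer.
    Order of first appearance is preserved.
    """
    kept = []
    for table_id, distance in scored_hits:
        if distance <= max_distance and table_id not in kept:
            kept.append(table_id)
    return kept


# Invented distances for the "crash sites" query:
# two Aviation.Event chunks (id 3) and one Aviation.Aircraft chunk (id 1)
hits = [(3, 0.18), (3, 0.21), (1, 0.34)]
print(filter_table_ids(hits, max_distance=0.25))
```

With these made-up numbers only table 3 survives the cutoff; choosing a sensible threshold for real embeddings would require experimentation.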
To automate this process, define a function to filter and return the relevant tables based on user input:
def get_relevant_tables(user_input, tables_vector_store, table_df):
    relevant_tables_docs = tables_vector_store.similarity_search(user_input)
    relevant_tables_docs_indices = [x.metadata["id"] for x in relevant_tables_docs]
    indices = table_df["id"].isin(relevant_tables_docs_indices)
    relevant_tables_array = [x for x in table_df[indices]["col_def"]]
    return relevant_tables_array
This function will help in efficiently retrieving only the relevant tables to send to the LLM, reducing the prompt length and improving overall query performance.
When working with language models (LLMs), providing them with relevant examples helps ensure accurate and contextually appropriate responses. These examples, referred to as "few-shot" examples, guide the LLM in understanding the structure and context of the queries it should handle.
In our case, we need to populate the examples_value variable with a diverse set of SQL queries that cover a broad spectrum of IRIS SQL syntax and the tables available in the database. This helps prevent the LLM from generating incorrect or irrelevant queries.
Below is a list of example queries designed to illustrate various SQL operations:
examples = [
    {"input": "List all aircrafts.",
     "query": "SELECT * FROM Aviation.Aircraft"},
    {"input": "Find all incidents for the aircraft with ID 'N12345'.",
     "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE ID = 'N12345')"},
    {"input": "List all incidents in the 'Commercial' operation type.",
     "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE OperationType = 'Commercial')"},
    {"input": "Find the total number of incidents.",
     "query": "SELECT COUNT(*) FROM Aviation.Event"},
    {"input": "List all incidents that occurred in 'Canada'.",
     "query": "SELECT * FROM Aviation.Event WHERE LocationCountry = 'Canada'"},
    {"input": "How many incidents are associated with the aircraft with AircraftKey 5?",
     "query": "SELECT COUNT(*) FROM Aviation.Aircraft WHERE AircraftKey = 5"},
    {"input": "Find the total number of distinct aircrafts involved in incidents.",
     "query": "SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft"},
    {"input": "List all incidents that occurred after 5 PM.",
     "query": "SELECT * FROM Aviation.Event WHERE EventTime > 1700"},
    {"input": "Who are the top 5 operators by the number of incidents?",
     "query": "SELECT TOP 5 OperatorName, COUNT(*) AS IncidentCount FROM Aviation.Aircraft GROUP BY OperatorName ORDER BY IncidentCount DESC"},
    {"input": "Which incidents occurred in the year 2020?",
     "query": "SELECT * FROM Aviation.Event WHERE YEAR(EventDate) = '2020'"},
    {"input": "What was the month with most events in the year 2020?",
     "query": "SELECT TOP 1 MONTH(EventDate) EventMonth, COUNT(*) EventCount FROM Aviation.Event WHERE YEAR(EventDate) = '2020' GROUP BY MONTH(EventDate) ORDER BY EventCount DESC"},
    {"input": "How many crew members were involved in incidents?",
     "query": "SELECT COUNT(*) FROM Aviation.Crew"},
    {"input": "List all incidents with detailed aircraft information for incidents that occurred in the year 2012.",
     "query": "SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012"},
    {"input": "Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.",
     "query": "SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5"},
    {"input": "List all crew members involved in incidents with serious injuries, along with the incident date and location.",
     "query": "SELECT c.CrewNumber AS 'Crew Number', c.Age, c.Sex AS Gender, e.EventDate AS 'Event Date', e.LocationCity AS 'Location City', e.LocationState AS 'Location State' FROM Aviation.Crew c JOIN Aviation.Event e ON c.EventId = e.EventId WHERE c.Injury = 'Serious'"}
]
Given the ever-expanding list of examples, it’s impractical to provide the LLM with all of them. Instead, we use IRIS Vector Search along with the SemanticSimilarityExampleSelector class to identify the most relevant examples based on user prompts.
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    IRISVector,
    k=5,
    input_keys=["input"],
    connection_string=iris_conn_str,
    collection_name="sql_samples",
    pre_delete_collection=True
)
Note: The pre_delete_collection flag is used here for demonstration purposes to ensure a fresh collection in each test run. In a production environment, this flag should be set to False to avoid unnecessary deletions.
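Conceptually, semantic example selection embeds the user input and each stored example, then returns the `k` nearest examples. The sketch below shows that idea with toy 3-dimensional vectors and cosine similarity in plain Python. The vectors are made up, and the helper names are hypothetical; the real selector delegates embedding to `OpenAIEmbeddings` and the search to the IRIS vector store.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def select_top_k(query_vec, examples, k):
    """Return the k example inputs most similar to query_vec."""
    ranked = sorted(
        examples,
        key=lambda ex: cosine_similarity(query_vec, ex["vector"]),
        reverse=True,
    )
    return [ex["input"] for ex in ranked[:k]]


# Toy embeddings: made-up numbers, for illustration only
toy_examples = [
    {"input": "Find the total number of incidents.",
     "vector": [0.9, 0.1, 0.0]},
    {"input": "List all aircrafts.",
     "vector": [0.1, 0.9, 0.0]},
    {"input": "How many crew members were involved in incidents?",
     "vector": [0.8, 0.2, 0.1]},
]
query_vector = [1.0, 0.0, 0.0]  # stands in for an embedded counting question
print(select_top_k(query_vector, toy_examples, k=2))
```

In this toy setup the two counting-style examples rank above the listing example because their vectors point in nearly the same direction as the query vector.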
To find the most relevant examples for a given input, use the selector as follows:
input_query = "Find all events in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model."
relevant_examples = example_selector.select_examples({"input": input_query})
The results might look like this:
[
    {'input': 'List all incidents with detailed aircraft information for incidents that occurred in the year 2012.',
     'query': 'SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012'},
    {'input': "Find all incidents for the aircraft with ID 'N12345'.",
     'query': "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE ID = 'N12345')"},
    {'input': 'Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.',
     'query': 'SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5'},
    {'input': 'List all aircrafts.',
     'query': 'SELECT * FROM Aviation.Aircraft'},
    {'input': 'Find the total number of distinct aircrafts involved in incidents.',
     'query': 'SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft'}
]
If you specifically need examples related to quantities, you can query the selector accordingly:
input_query = "What is the number of incidents involving Boeing aircraft."
quantity_examples = example_selector.select_examples({"input": input_query})
The output may be:
[
    {'input': 'How many incidents are associated with the aircraft with AircraftKey 5?',
     'query': 'SELECT COUNT(*) FROM Aviation.Aircraft WHERE AircraftKey = 5'},
    {'input': 'Find the total number of distinct aircrafts involved in incidents.',
     'query': 'SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft'},
    {'input': 'How many crew members were involved in incidents?',
     'query': 'SELECT COUNT(*) FROM Aviation.Crew'},
    {'input': 'Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.',
     'query': 'SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5'},
    {'input': 'List all incidents with detailed aircraft information for incidents that occurred in the year 2012.',
     'query': 'SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012'}
]
This output includes examples that specifically address counting and quantities.
While the SemanticSimilarityExampleSelector is powerful, it’s important to note that not all selected examples may be perfect. Future improvements may involve adding filters or thresholds to exclude less relevant results, ensuring that only the most appropriate examples are provided to the LLM.
To assess the performance of the prompt and SQL query generation, we need to set up and run a series of tests. The goal is to evaluate how well the LLM generates SQL queries based on user inputs, with and without the use of example-based few shots.
We start by defining a function that uses the LLM to generate SQL queries based on the provided context, prompt, user input, and other parameters:
def get_sql_from_text(context, prompt, user_input, use_few_shots, tables_vector_store, table_df,
                      example_selector=None, example_prompt=None):
    # Note: use_few_shots is not read in this function; few shots are
    # enabled by passing example_selector and example_prompt
    relevant_tables = get_relevant_tables(user_input, tables_vector_store, table_df)
    context["table_info"] = "\n\n".join(relevant_tables)

    examples = example_selector.select_examples({"input": user_input}) if example_selector else []
    context["examples_value"] = "\n\n".join([
        example_prompt.invoke(x).to_string() for x in examples
    ])

    model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    output_parser = StrOutputParser()
    chain_model = prompt | model | output_parser

    response = chain_model.invoke({
        "top_k": context["top_k"],
        "table_info": context["table_info"],
        "examples_value": context["examples_value"],
        "input": user_input
    })
    return response
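The prompt asks for plain SQL with no formatting, but a model does not always follow such an instruction. A defensive normalization step before execution is one option; the original notebook does not include one, so the `clean_sql` helper below is a hypothetical sketch. It only removes a surrounding Markdown code fence and a trailing semicolon, and makes no attempt to repair invalid SQL.

```python
FENCE = "`" * 3  # Markdown code-fence marker (three backticks)


def clean_sql(response):
    """Strip a surrounding Markdown code fence and a trailing semicolon."""
    text = response.strip()
    if text.startswith(FENCE):
        # Drop the opening fence line (the fence plus an optional language tag)
        text = text.split("\n", 1)[1] if "\n" in text else ""
    if text.rstrip().endswith(FENCE):
        text = text.rstrip()[: -len(FENCE)]
    return text.strip().rstrip(";").strip()


fenced = FENCE + "sql\nSELECT COUNT(*) FROM Aviation.Event;\n" + FENCE
print(clean_sql(fenced))
print(clean_sql("SELECT TOP 3 EventId FROM Aviation.Event"))
```

A helper like this would be applied to the value returned by `get_sql_from_text` before passing it to the database.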
Test the prompt with and without examples:
# Prompt execution **with** few shots
input = "Find all events in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model."
response_with_few_shots = get_sql_from_text(
    context,
    prompt,
    user_input=input,
    use_few_shots=True,
    tables_vector_store=tables_vector_store,
    table_df=table_df,
    example_selector=example_selector,
    example_prompt=example_prompt,
)
print(response_with_few_shots)
SELECT e.EventId, e.EventDate, e.LocationCity, e.LocationState, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2010
# Prompt execution **without** few shots
input = "Find all events in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model."
response_with_no_few_shots = get_sql_from_text(
    context,
    prompt,
    user_input=input,
    use_few_shots=False,
    tables_vector_store=tables_vector_store,
    table_df=table_df,
)
print(response_with_no_few_shots)
SELECT TOP 3 "EventId", "EventDate", "LocationCity", "LocationState", "AircraftManufacturer", "AircraftModel" FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.ID = a.Event WHERE e.EventDate >= '2010-01-01' AND e.EventDate < '2011-01-01'
To test the generated SQL queries, we define some utility functions:
def execute_sql_query(cnx, query):
    # Returns the fetched rows, or None if the query fails to run
    try:
        cursor = cnx.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()
        return rows
    except Exception:
        print('error on running query:')
        print(query)
        print('-'*80)
        return None

def sql_result_equals(cnx, query, expected):
    # Each row is compared as a set of its values, so column order is ignored
    rows = execute_sql_query(cnx, query)
    result = [set(row._asdict().values()) for row in rows or []]
    if result != expected and rows is not None:
        print('result not expected for query:')
        print(query)
        print('-'*80)
    return result == expected
# SQL test for prompt **with** few shots
print("SQL is OK" if execute_sql_query(cnx, response_with_few_shots) is not None else "SQL is not OK")
SQL is OK
# SQL test for prompt **without** few shots
print("SQL is OK" if execute_sql_query(cnx, response_with_no_few_shots) is not None else "SQL is not OK")
error on running query:
SELECT TOP 3 "EventId", "EventDate", "LocationCity", "LocationState", "AircraftManufacturer", "AircraftModel" FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.ID = a.Event WHERE e.EventDate >= '2010-01-01' AND e.EventDate < '2011-01-01'
--------------------------------------------------------------------------------
SQL is not OK
Define a set of test cases and run them:
tests = [{
    "input": "What were the top 3 years with the most recorded events?",
    "expected": [{128, 2003}, {122, 2007}, {117, 2005}]
},{
    "input": "How many incidents involving Boeing aircraft.",
    "expected": [{5}]
},{
    "input": "How many incidents that resulted in fatalities.",
    "expected": [{237}]
},{
    "input": "List event Id and date and, crew number, age and gender for incidents that occurred in 2013.",
    "expected": [
        {1, datetime.datetime(2013, 3, 4, 11, 6), '20130305X71252', 59, 'M'},
        {1, datetime.datetime(2013, 1, 1, 15, 0), '20130101X94035', 32, 'M'},
        {2, datetime.datetime(2013, 1, 1, 15, 0), '20130101X94035', 35, 'M'},
        {1, datetime.datetime(2013, 1, 12, 15, 0), '20130113X42535', 25, 'M'},
        {2, datetime.datetime(2013, 1, 12, 15, 0), '20130113X42535', 34, 'M'},
        {1, datetime.datetime(2013, 2, 1, 15, 0), '20130203X53401', 29, 'M'},
        {1, datetime.datetime(2013, 2, 15, 15, 0), '20130218X70747', 27, 'M'},
        {1, datetime.datetime(2013, 3, 2, 15, 0), '20130303X21011', 49, 'M'},
        {1, datetime.datetime(2013, 3, 23, 13, 52), '20130326X85150', 'M', None}
    ]
},{
    "input": "Find the total number of incidents that occurred in the United States.",
    "expected": [{1178}]
},{
    "input": "List all incidents latitude and longitude coordinates with more than 5 injuries that occurred in 2010.",
    "expected": [{-78.76833333333333, 43.25277777777778}]
},{
    "input": "Find all incidents in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model.",
    "expected": [
        {datetime.datetime(2010, 5, 20, 13, 43), '20100520X60222', 'CIRRUS DESIGN CORP', 'Farmingdale', 'New York', 'SR22'},
        {datetime.datetime(2010, 4, 11, 15, 0), '20100411X73253', 'CZECH AIRCRAFT WORKS SPOL SRO', 'Millbrook', 'New York', 'SPORTCRUISER'},
        {'108', datetime.datetime(2010, 1, 9, 12, 55), '20100111X41106', 'Bayport', 'New York', 'STINSON'},
        {datetime.datetime(2010, 8, 1, 14, 20), '20100801X85218', 'A185F', 'CESSNA', 'New York', 'Newfane'}
    ]
}]
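A note on how `expected` is matched: `sql_result_equals` turns each returned row into a set of its values and compares the resulting list with `expected`. The consequences can be checked in isolation, with plain tuples standing in for database rows; `rows_as_sets` is a hypothetical helper that mirrors that conversion.

```python
def rows_as_sets(rows):
    """Convert each row (a sequence of values) to a set of its values."""
    return [set(row) for row in rows]


expected = [{128, 2003}, {122, 2007}]

# Column order within a row does not matter...
assert rows_as_sets([(2003, 128), (2007, 122)]) == expected
assert rows_as_sets([(128, 2003), (122, 2007)]) == expected

# ...but row order does, because the outer container is a list
assert rows_as_sets([(2007, 122), (2003, 128)]) != expected

# Duplicate values inside a row collapse, so (5, 5) and (5,) look identical
assert rows_as_sets([(5, 5)]) == rows_as_sets([(5,)])
print("all comparisons behave as described")
```

This makes the tests tolerant of the column order the LLM chooses, at the cost of being sensitive to row ordering and blind to repeated values within a row.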
Run the tests and calculate the accuracy:
def execute_tests(cnx, context, prompt, use_few_shots, tables_vector_store, table_df, example_selector, example_prompt):
    # Generate SQL for every test input
    tests_generated_sql = [(x, get_sql_from_text(
        context,
        prompt,
        user_input=x['input'],
        use_few_shots=use_few_shots,
        tables_vector_store=tables_vector_store,
        table_df=table_df,
        example_selector=example_selector if use_few_shots else None,
        example_prompt=example_prompt if use_few_shots else None,
    )) for x in deepcopy(tests)]

    # Run each generated query and compare its result with the expected one
    tests_sql_executions = [(x[0], sql_result_equals(cnx, x[1], x[0]['expected']))
                            for x in tests_generated_sql]

    accuracy = sum(1 for i in tests_sql_executions if i[1]) / len(tests_sql_executions)
    print(f'accuracy: {accuracy}')
    print('-'*80)
# Accuracy tests for prompts executed **without** few shots
use_few_shots = False
execute_tests(
    cnx,
    context,
    prompt,
    use_few_shots,
    tables_vector_store,
    table_df,
    example_selector,
    example_prompt
)
error on running query:
SELECT "EventDate", COUNT("EventId") as "TotalEvents" FROM Aviation.Event GROUP BY "EventDate" ORDER BY "TotalEvents" DESC TOP 3;
--------------------------------------------------------------------------------
error on running query:
SELECT "EventId", "EventDate", "C"."CrewNumber", "C"."Age", "C"."Sex" FROM "Aviation.Event" AS "E" JOIN "Aviation.Crew" AS "C" ON "E"."ID" = "C"."EventId" WHERE "E"."EventDate" >= '2013-01-01' AND "E"."EventDate" < '2014-01-01'
--------------------------------------------------------------------------------
result not expected for query:
SELECT TOP 3 "e"."EventId", "e"."EventDate", "e"."LocationCity", "e"."LocationState", "a"."AircraftManufacturer", "a"."AircraftModel" FROM "Aviation"."Event" AS "e" JOIN "Aviation"."Aircraft" AS "a" ON "e"."ID" = "a"."Event" WHERE "e"."EventDate" >= '2010-01-01' AND "e"."EventDate" < '2011-01-01'
--------------------------------------------------------------------------------
accuracy: 0.5714285714285714
--------------------------------------------------------------------------------
# Accuracy tests for prompts executed **with** few shots
use_few_shots = True
execute_tests(
    cnx,
    context,
    prompt,
    use_few_shots,
    tables_vector_store,
    table_df,
    example_selector,
    example_prompt
)
error on running query:
SELECT e.EventId, e.EventDate, e.LocationCity, e.LocationState, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2010 TOP 3
--------------------------------------------------------------------------------
accuracy: 0.8571428571428571
--------------------------------------------------------------------------------
In this run, SQL queries generated with few-shot examples were 50% more accurate, in relative terms, than those generated without them: about 86% vs. 57%, that is, 6 vs. 4 of the 7 test cases.
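The comparison can be checked from the reported accuracies. With 7 test cases, 0.5714... corresponds to 4 passing tests and 0.8571... to 6. A small sketch of the arithmetic, using exact fractions to avoid rounding effects:

```python
from fractions import Fraction

total_tests = 7
passed_without = 4  # matches the reported 0.5714285714285714
passed_with = 6     # matches the reported 0.8571428571428571

acc_without = Fraction(passed_without, total_tests)
acc_with = Fraction(passed_with, total_tests)

# Relative gain compares the improvement with the baseline accuracy;
# absolute gain is the plain difference between the two accuracies
relative_gain = (acc_with - acc_without) / acc_without
absolute_gain = acc_with - acc_without

print(f"without: {float(acc_without):.4f}, with: {float(acc_with):.4f}")
print(f"relative gain: {float(relative_gain):.0%}")
print(f"absolute gain: {float(absolute_gain) * 100:.1f} percentage points")
```

The relative gain is exactly one half (50%), while the absolute difference is about 28.6 percentage points. With only seven test cases, both figures should be read as indicative rather than statistically robust.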