organize things
This commit is contained in:
112
poller/main.py
112
poller/main.py
@@ -1,5 +1,6 @@
|
||||
from pgvector.psycopg import register_vector, Bit
|
||||
from urllib.parse import quote
|
||||
from psycopg.rows import dict_row
|
||||
from urllib.parse import unquote
|
||||
import anthropic
|
||||
import cohere
|
||||
import dotenv
|
||||
@@ -24,24 +25,19 @@ conn = psycopg.connect(
|
||||
dbname='cavepediav2_db',
|
||||
user='cavepediav2_user',
|
||||
password='cavepediav2_pw',
|
||||
row_factory=dict_row,
|
||||
)
|
||||
|
||||
## init
|
||||
def create_tables():
|
||||
commands = (
|
||||
"CREATE EXTENSION IF NOT EXISTS vector",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS events (
|
||||
key TEXT PRIMARY KEY,
|
||||
value JSONB
|
||||
)
|
||||
""",
|
||||
"DROP TABLE IF EXISTS embeddings",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS embeddings (
|
||||
bucket TEXT,
|
||||
key TEXT,
|
||||
content TEXT,
|
||||
embedding bit(1536),
|
||||
embedding vector(1536),
|
||||
PRIMARY KEY (bucket, key)
|
||||
)
|
||||
""")
|
||||
@@ -50,50 +46,7 @@ def create_tables():
|
||||
conn.commit()
|
||||
register_vector(conn)
|
||||
|
||||
def insert_text(bucket, key, text):
|
||||
with conn.cursor() as cur:
|
||||
command = 'INSERT INTO embeddings (bucket, key, content) VALUES (%s, %s, %s);'
|
||||
cur.execute(command, (bucket, key, text))
|
||||
conn.commit()
|
||||
|
||||
def process_events():
|
||||
with conn.cursor() as cur:
|
||||
cur.execute('SELECT * FROM events')
|
||||
|
||||
for row in cur.fetchall():
|
||||
for record in row[1]['Records']:
|
||||
bucket = record['s3']['bucket']['name']
|
||||
key = record['s3']['object']['key']
|
||||
|
||||
ai = extract_data(bucket, key)
|
||||
text = ai.content[0].text
|
||||
text = text.replace('\n',' ')
|
||||
insert_text(bucket, key, text)
|
||||
|
||||
# https://github.com/pgvector/pgvector-python/blob/master/examples/cohere/example.py
|
||||
def embed(text, input_type):
|
||||
resp = co.embed(
|
||||
texts=[text],
|
||||
model='embed-v4.0',
|
||||
input_type=input_type,
|
||||
embedding_types=['ubinary'],
|
||||
)
|
||||
return [np.unpackbits(np.array(embedding, dtype=np.uint8)) for embedding in resp.embeddings.ubinary]
|
||||
|
||||
def generate_embeddings():
|
||||
cur = conn.cursor()
|
||||
cur.execute('SELECT * FROM embeddings WHERE embedding IS NULL')
|
||||
rows = cur.fetchall()
|
||||
|
||||
for row in rows:
|
||||
inputs = ['mycontent']
|
||||
embeddings=embed(row[2], 'search_document')
|
||||
|
||||
for content, embedding in zip(inputs, embeddings):
|
||||
conn.execute('INSERT INTO embeddings (bucket, key, content, embedding) VALUES (%s, %s, %s, %s)', ('mybucket', 'mykey', content, Bit(embedding).to_text()))
|
||||
conn.commit()
|
||||
|
||||
# sql = 'UPDATE embeddings SET embedding = %s', (Bit(embeddings[0]))
|
||||
## processing
|
||||
def get_presigned_url(bucket, key) -> str:
|
||||
client = minio.Minio(
|
||||
's3.bigcavemaps.com',
|
||||
@@ -102,7 +55,7 @@ def get_presigned_url(bucket, key) -> str:
|
||||
region='kansascity',
|
||||
)
|
||||
|
||||
url = client.presigned_get_object(bucket, key)
|
||||
url = client.presigned_get_object(bucket, unquote(key))
|
||||
return url
|
||||
|
||||
def extract_data(bucket, key):
|
||||
@@ -135,16 +88,49 @@ def extract_data(bucket, key):
|
||||
)
|
||||
return message
|
||||
|
||||
def search():
|
||||
query = 'door'
|
||||
query_embedding = embed(query, 'search_query')[0]
|
||||
def process_events():
|
||||
rows = conn.execute('SELECT * FROM events')
|
||||
|
||||
rows = conn.execute('SELECT content FROM embeddings ORDER BY embedding <~> %s LIMIT 5', (Bit(query_embedding).to_text(),)).fetchall()
|
||||
for row in rows:
|
||||
print(row)
|
||||
for record in row['event_data']['Records']:
|
||||
bucket = record['s3']['bucket']['name']
|
||||
key = record['s3']['object']['key']
|
||||
print(f'PROCESSING event_time: {row["event_time"]}, bucket: {bucket}, key: {key}')
|
||||
print()
|
||||
|
||||
ai_ocr = extract_data(bucket, key)
|
||||
text = ai_ocr.content[0].text
|
||||
text = text.replace('\n',' ')
|
||||
|
||||
with conn.cursor() as cur:
|
||||
sql = 'INSERT INTO embeddings (bucket, key, content) VALUES (%s, %s, %s);'
|
||||
cur.execute(sql, (bucket, key, text))
|
||||
cur.execute('DELETE FROM events WHERE event_time = %s', (row['event_time'],))
|
||||
conn.commit()
|
||||
|
||||
### embeddings
|
||||
# https://github.com/pgvector/pgvector-python/blob/master/examples/cohere/example.py
|
||||
def embed(text, input_type):
|
||||
resp = co.embed(
|
||||
texts=[text],
|
||||
model='embed-v4.0',
|
||||
input_type=input_type,
|
||||
embedding_types=['float'],
|
||||
)
|
||||
return resp.embeddings.float[0]
|
||||
|
||||
def generate_embeddings():
|
||||
cur = conn.cursor()
|
||||
cur.execute('SELECT * FROM embeddings WHERE embedding IS NULL')
|
||||
rows = cur.fetchall()
|
||||
|
||||
for row in rows:
|
||||
embedding=embed(row['content'], 'search_document')
|
||||
|
||||
conn.execute('UPDATE embeddings SET embedding = %s::vector WHERE bucket = %s AND key = %s', (embedding, row['bucket'], row['key']))
|
||||
conn.commit()
|
||||
|
||||
if __name__ == '__main__':
|
||||
# create_tables()
|
||||
# process_events()
|
||||
# generate_embeddings()
|
||||
search()
|
||||
create_tables()
|
||||
process_events()
|
||||
generate_embeddings()
|
||||
|
||||
49
poller/search.py
Normal file
49
poller/search.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from pgvector.psycopg import register_vector, Bit
|
||||
from psycopg.rows import dict_row
|
||||
from urllib.parse import unquote
|
||||
import anthropic
|
||||
import cohere
|
||||
import dotenv
|
||||
import datetime
|
||||
import json
|
||||
import minio
|
||||
import numpy as np
|
||||
import os
|
||||
import psycopg
|
||||
import time
|
||||
|
||||
dotenv.load_dotenv('/home/paul/scripts-private/lech/cavepedia-v2/poller.env')
|
||||
|
||||
COHERE_API_KEY = os.getenv('COHERE_API_KEY')
|
||||
|
||||
co = cohere.ClientV2(COHERE_API_KEY)
|
||||
conn = psycopg.connect(
|
||||
host='127.0.0.1',
|
||||
port=4010,
|
||||
dbname='cavepediav2_db',
|
||||
user='cavepediav2_user',
|
||||
password='cavepediav2_pw',
|
||||
row_factory=dict_row,
|
||||
)
|
||||
|
||||
def embed(text, input_type):
|
||||
resp = co.embed(
|
||||
texts=[text],
|
||||
model='embed-v4.0',
|
||||
input_type=input_type,
|
||||
embedding_types=['float'],
|
||||
)
|
||||
return resp.embeddings.float[0]
|
||||
|
||||
def search():
|
||||
query = 'sex'
|
||||
query_embedding = embed(query, 'search_query')
|
||||
|
||||
rows = conn.execute('SELECT * FROM embeddings ORDER BY embedding <=> %s::vector LIMIT 5', (query_embedding,)).fetchall()
|
||||
for row in rows:
|
||||
print(row['bucket'])
|
||||
print(row['key'])
|
||||
print(row['content'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
search()
|
||||
@@ -1,6 +0,0 @@
|
||||
--2025-05-24 18:17:39-- https://s3.bigcavemaps.com:9000/arn:aws:s3:::cavepedia-v2/doorwarning.pdf?X-Amz-Algorithm=AWS4-HMAC-SHA256
|
||||
Resolving s3.bigcavemaps.com (s3.bigcavemaps.com)... 2606:d640:0:10::2, 104.167.221.74
|
||||
Connecting to s3.bigcavemaps.com (s3.bigcavemaps.com)|2606:d640:0:10::2|:9000... connected.
|
||||
HTTP request sent, awaiting response... 400 Bad Request
|
||||
2025-05-24 18:17:39 ERROR 400: Bad Request.
|
||||
|
||||
Reference in New Issue
Block a user