lint poller
poller/main.py (267)
@@ -1,21 +1,18 @@
from cohere.core.api_error import ApiError
from pgvector.psycopg import register_vector, Bit
from psycopg.rows import dict_row
import io
import logging
import os
import time
from urllib.parse import unquote
from pypdf import PdfReader, PdfWriter

import anthropic
import boto3
import cohere
import dotenv
import datetime
import io
import json
import minio
import numpy as np
import os
import psycopg
import time
import logging
from cohere.core.api_error import ApiError
from pgvector.psycopg import register_vector
from psycopg.rows import dict_row
from pypdf import PdfReader, PdfWriter
from pythonjsonlogger.json import JsonFormatter

logger = logging.getLogger(__name__)
@@ -31,21 +28,21 @@ logger.addHandler(logHandler)
dotenv.load_dotenv()

# Required environment variables
COHERE_API_KEY = os.environ['COHERE_API_KEY']
S3_ACCESS_KEY = os.environ['S3_ACCESS_KEY']
S3_SECRET_KEY = os.environ['S3_SECRET_KEY']
S3_ENDPOINT = os.environ.get('S3_ENDPOINT', 'https://s3.bigcavemaps.com')
S3_REGION = os.environ.get('S3_REGION', 'eu')
COHERE_API_KEY = os.environ["COHERE_API_KEY"]
S3_ACCESS_KEY = os.environ["S3_ACCESS_KEY"]
S3_SECRET_KEY = os.environ["S3_SECRET_KEY"]
S3_ENDPOINT = os.environ.get("S3_ENDPOINT", "https://s3.bigcavemaps.com")
S3_REGION = os.environ.get("S3_REGION", "eu")

# Database config
DB_HOST = os.environ.get('DB_HOST', 'localhost')
DB_PORT = int(os.environ.get('DB_PORT', '5432'))
DB_NAME = os.environ.get('DB_NAME', 'cavepediav2_db')
DB_USER = os.environ.get('DB_USER', 'cavepediav2_user')
DB_PASSWORD = os.environ['DB_PASSWORD']
DB_HOST = os.environ.get("DB_HOST", "localhost")
DB_PORT = int(os.environ.get("DB_PORT", "5432"))
DB_NAME = os.environ.get("DB_NAME", "cavepediav2_db")
DB_USER = os.environ.get("DB_USER", "cavepediav2_user")
DB_PASSWORD = os.environ["DB_PASSWORD"]

s3 = boto3.client(
's3',
"s3",
aws_access_key_id=S3_ACCESS_KEY,
aws_secret_access_key=S3_SECRET_KEY,
endpoint_url=S3_ENDPOINT,
@@ -61,6 +58,7 @@ conn = psycopg.connect(
row_factory=dict_row,
)


## init
# events table is created by minio up creation of event destination
def create_tables():
@@ -94,54 +92,58 @@ def create_tables():
embedding vector(1536),
UNIQUE(bucket, key)
)
""")
""",
)
for command in commands:
conn.execute(command)
conn.commit()
register_vector(conn)


def import_files():
"""Scan import bucket for any new files; move them to the files bucket and add to db; delete from import bucket"""
BUCKET_IMPORT = 'cavepediav2-import'
BUCKET_FILES = 'cavepediav2-files'
BUCKET_IMPORT = "cavepediav2-import"
BUCKET_FILES = "cavepediav2-files"
# get new files; add to db, sync to main bucket; delete from import bucket
response = s3.list_objects_v2(Bucket=BUCKET_IMPORT)
if 'Contents' in response:
for obj in response['Contents']:
if obj['Key'].endswith('/'):
if "Contents" in response:
for obj in response["Contents"]:
if obj["Key"].endswith("/"):
continue
s3.copy_object(
CopySource={'Bucket': BUCKET_IMPORT, 'Key': obj['Key']},
CopySource={"Bucket": BUCKET_IMPORT, "Key": obj["Key"]},
Bucket=BUCKET_FILES,
Key=obj['Key'],
Key=obj["Key"],
)
conn.execute('INSERT INTO metadata (bucket, key) VALUES(%s, %s);', (BUCKET_FILES, obj['Key']))
conn.execute("INSERT INTO metadata (bucket, key) VALUES(%s, %s);", (BUCKET_FILES, obj["Key"]))
conn.commit()
s3.delete_object(
Bucket=BUCKET_IMPORT,
Key=obj['Key'],
Key=obj["Key"],
)


def split_files():
"""Split PDFs into single pages for easier processing"""
BUCKET_PAGES = 'cavepediav2-pages'
BUCKET_PAGES = "cavepediav2-pages"
rows = conn.execute("SELECT COUNT(*) FROM metadata WHERE split = false")
row = rows.fetchone()
logger.info(f'Found {row["count"]} files to split.')
rows = conn.execute('SELECT * FROM metadata WHERE split = false')

assert row is not None
logger.info(f"Found {row['count']} files to split.")
rows = conn.execute("SELECT * FROM metadata WHERE split = false")

for row in rows:
bucket = row['bucket']
key = row['key']
bucket = row["bucket"]
key = row["key"]

with conn.cursor() as cur:
logger.info(f'SPLITTING bucket: {bucket}, key: {key}')
logger.info(f"SPLITTING bucket: {bucket}, key: {key}")

##### get pdf #####
s3.download_file(bucket, key, '/tmp/file.pdf')
s3.download_file(bucket, key, "/tmp/file.pdf")

##### split #####
with open('/tmp/file.pdf', 'rb') as f:
with open("/tmp/file.pdf", "rb") as f:
reader = PdfReader(f)

for i in range(len(reader.pages)):
@@ -151,129 +153,120 @@ def split_files():
with io.BytesIO() as bs:
writer.write(bs)
bs.seek(0)
s3.put_object(
Bucket=BUCKET_PAGES,
Key=f'{key}/page-{i + 1}.pdf',
Body=bs.getvalue()
)
cur.execute('INSERT INTO embeddings (bucket, key, role) VALUES (%s, %s, %s);', (BUCKET_PAGES, f'{key}/page-{i + 1}.pdf', key.split('/')[0]))
cur.execute('UPDATE metadata SET SPLIT = true WHERE id = %s', (row['id'],));
s3.put_object(Bucket=BUCKET_PAGES, Key=f"{key}/page-{i + 1}.pdf", Body=bs.getvalue())
page_key = f"{key}/page-{i + 1}.pdf"
role = key.split("/")[0]
cur.execute(
"INSERT INTO embeddings (bucket, key, role) VALUES (%s, %s, %s);",
(BUCKET_PAGES, page_key, role),
)
cur.execute("UPDATE metadata SET SPLIT = true WHERE id = %s", (row["id"],))
conn.commit()


def ocr_create_message(id, bucket, key):
"""Create message to send to claude"""
url = s3.generate_presigned_url(
'get_object',
Params={
'Bucket': bucket,
'Key': unquote(key)
},
"get_object",
Params={"Bucket": bucket, "Key": unquote(key)},
)

message = {
'custom_id': f'doc-{id}',
'params': {
'model': 'claude-haiku-4-5',
'max_tokens': 4000,
'temperature': 1,
'messages': [
"custom_id": f"doc-{id}",
"params": {
"model": "claude-haiku-4-5",
"max_tokens": 4000,
"temperature": 1,
"messages": [
{
'role': 'user',
'content': [
"role": "user",
"content": [
{"type": "document", "source": {"type": "url", "url": url}},
{
'type': 'document',
'source': {
'type': 'url',
'url': url
}
"type": "text",
"text": "Extract all text from this document. "
"Do not include any summary or conclusions of your own.",
},
{
'type': 'text',
'text': 'Extract all text from this document. Do not include any summary or conclusions of your own.'
}
]
],
}
]
}
],
},
}


return message


def ocr(bucket, key):
"""Gets OCR content of pdfs"""
url = s3.generate_presigned_url(
'get_object',
Params={
'Bucket': bucket,
'Key': unquote(key)
},
"get_object",
Params={"Bucket": bucket, "Key": unquote(key)},
)

client = anthropic.Anthropic()
message = client.messages.create(
model='claude-haiku-4-5',
model="claude-haiku-4-5",
max_tokens=4000,
temperature=1,
messages=[
{
'role': 'user',
'content': [
"role": "user",
"content": [
{"type": "document", "source": {"type": "url", "url": url}},
{
'type': 'document',
'source': {
'type': 'url',
'url': url
}
"type": "text",
"text": "Extract all text from this document. "
"Do not include any summary or conclusions of your own.",
},
{
'type': 'text',
'text': 'Extract all text from this document. Do not include any summary or conclusions of your own.'
}
]
],
}
],
)
return message


def claude_send_batch(batch):
"""Send a batch to claude"""
client = anthropic.Anthropic()
message_batch = client.messages.batches.create(
requests=batch
)
message_batch = client.messages.batches.create(requests=batch)

conn.execute('INSERT INTO batches (platform, batch_id, type) VALUES(%s, %s, %s);', ('claude', message_batch.id, 'ocr'))
conn.execute(
"INSERT INTO batches (platform, batch_id, type) VALUES(%s, %s, %s);", ("claude", message_batch.id, "ocr")
)
conn.commit()

logger.info(f'Sent batch_id {message_batch.id} to claude')
logger.info(f"Sent batch_id {message_batch.id} to claude")


def check_batches():
"""Check batch status"""
rows = conn.execute("SELECT COUNT(*) FROM batches WHERE done = false")
row = rows.fetchone()
logger.info(f'Found {row["count"]} batch(es) to process.')
assert row is not None
logger.info(f"Found {row['count']} batch(es) to process.")
rows = conn.execute("SELECT * FROM batches WHERE done = false")


client = anthropic.Anthropic()
for row in rows:
message_batch = client.messages.batches.retrieve(
row['batch_id'],
row["batch_id"],
)
if message_batch.processing_status == 'ended':
if message_batch.processing_status == "ended":
results = client.messages.batches.results(
row['batch_id'],
row["batch_id"],
)
with conn.cursor() as cur:
for result in results:
id = int(result.custom_id.split('-')[1])
id = int(result.custom_id.split("-")[1])
try:
content = result.result.message.content[0].text
cur.execute('UPDATE embeddings SET content = %s WHERE id = %s;', (content, id))
except:
cur.execute('UPDATE embeddings SET content = %s WHERE id = %s;', ('ERROR', id))
cur.execute('UPDATE batches SET done = true WHERE batch_id = %s;', (row['batch_id'],))
content = result.result.message.content[0].text # type: ignore[union-attr]
cur.execute("UPDATE embeddings SET content = %s WHERE id = %s;", (content, id))
except Exception:
cur.execute("UPDATE embeddings SET content = %s WHERE id = %s;", ("ERROR", id))
cur.execute("UPDATE batches SET done = true WHERE batch_id = %s;", (row["batch_id"],))
conn.commit()


def ocr_main():
"""Checks for any non-OCR'd documents and sends them to claude in batches"""
## claude 4 sonnet ##
@@ -284,37 +277,49 @@ def ocr_main():
# get docs where content is null
rows = conn.execute("SELECT COUNT(*) FROM embeddings WHERE content IS NULL LIMIT 1000")
row = rows.fetchone()
logger.info(f'Batching {row["count"]} documents to generate OCR content.')
assert row is not None
logger.info(f"Batching {row['count']} documents to generate OCR content.")
rows = conn.execute("SELECT * FROM embeddings WHERE content IS NULL LIMIT 1000")

# batch docs; set content = WIP
batch = []
for row in rows:
id = row['id']
bucket = row['bucket']
key = row['key']
id = row["id"]
bucket = row["bucket"]
key = row["key"]

logger.info(f'Batching for OCR: {bucket}, key: {key}')
logger.info(f"Batching for OCR: {bucket}, key: {key}")

batch.append(ocr_create_message(id, bucket, key))
conn.execute('UPDATE embeddings SET content = %s WHERE id = %s;', ('WIP', id))
conn.execute("UPDATE embeddings SET content = %s WHERE id = %s;", ("WIP", id))
conn.commit()
if len(batch) > 0:
claude_send_batch(batch)


def embeddings_main():
"""Generate embeddings"""
rows = conn.execute("SELECT COUNT(*) FROM embeddings WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL")
count_query = """
SELECT COUNT(*) FROM embeddings
WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL
"""
rows = conn.execute(count_query)
row = rows.fetchone()
logger.info(f'Batching {row["count"]} documents to generate embeddings.')
rows = conn.execute("SELECT id, key, bucket, content FROM embeddings WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL")
assert row is not None
logger.info(f"Batching {row['count']} documents to generate embeddings.")
select_query = """
SELECT id, key, bucket, content FROM embeddings
WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL
"""
rows = conn.execute(select_query)

for row in rows:
logger.info(f'Generating embeddings for id: {row["id"]}, bucket: {row["bucket"]}, key: {row["key"]}')
embedding = embed(row['content'], 'search_document')
conn.execute('UPDATE embeddings SET embedding = %s::vector WHERE id = %s;', (embedding, row['id']))
logger.info(f"Generating embeddings for id: {row['id']}, bucket: {row['bucket']}, key: {row['key']}")
embedding = embed(row["content"], "search_document")
conn.execute("UPDATE embeddings SET embedding = %s::vector WHERE id = %s;", (embedding, row["id"]))
conn.commit()


### embeddings
def embed(text, input_type):
max_retries = 3
@@ -322,27 +327,31 @@ def embed(text, input_type):
try:
resp = co.embed(
texts=[text],
model='embed-v4.0',
model="embed-v4.0",
input_type=input_type,
embedding_types=['float'],
embedding_types=["float"],
output_dimension=1536,
)
return resp.embeddings.float[0]
assert resp.embeddings.float_ is not None
return resp.embeddings.float_[0]
except ApiError as e:
if e.status_code == 502 and attempt < max_retries - 1:
time.sleep(30 ** attempt) # exponential backoff
time.sleep(30**attempt) # exponential backoff
continue
raise Exception('cohere max retries exceeded')
raise Exception("cohere max retries exceeded")


def fix_pages():
i = 766
while i > 0:
conn.execute('UPDATE embeddings SET key = %s WHERE key = %s', (f'public/va/caves-of-virginia.pdf/page-{i}.pdf', f'public/va/caves-of-virginia.pdf/page-{i-1}.pdf'))
new_key = f"public/va/caves-of-virginia.pdf/page-{i}.pdf"
old_key = f"public/va/caves-of-virginia.pdf/page-{i - 1}.pdf"
conn.execute("UPDATE embeddings SET key = %s WHERE key = %s", (new_key, old_key))
conn.commit()
i -= 1


if __name__ == '__main__':
if __name__ == "__main__":
create_tables()

while True:
@@ -352,5 +361,5 @@ if __name__ == '__main__':
ocr_main()
embeddings_main()

logger.info('sleeping 5 minutes')
logger.info("sleeping 5 minutes")
time.sleep(5 * 60)