Files
cavepediav2/poller/main.py
Paul Walko 4ce3eb0abd
Some checks failed
Build and Push Agent Docker Image / build (push) Failing after 1m5s
Build and Push Poller Docker Image / build (push) Failing after 43s
Build and Push Web Docker Image / build (push) Successful in 3m35s
poller docker
2025-12-12 17:52:44 +01:00

357 lines
11 KiB
Python

from cohere.core.api_error import ApiError
from pgvector.psycopg import register_vector, Bit
from psycopg.rows import dict_row
from urllib.parse import unquote
from pypdf import PdfReader, PdfWriter
import anthropic
import boto3
import cohere
import dotenv
import datetime
import io
import json
import minio
import numpy as np
import os
import psycopg
import time
import logging
from pythonjsonlogger.json import JsonFormatter
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logHandler = logging.StreamHandler()
formatter = JsonFormatter("{asctime}{message}", style="{")
logHandler.setFormatter(formatter)
logger.addHandler(logHandler)
#####
# Load .env file if it exists (for local dev)
dotenv.load_dotenv()
# Required environment variables
COHERE_API_KEY = os.environ['COHERE_API_KEY']
S3_ACCESS_KEY = os.environ['S3_ACCESS_KEY']
S3_SECRET_KEY = os.environ['S3_SECRET_KEY']
S3_ENDPOINT = os.environ.get('S3_ENDPOINT', 'https://s3.bigcavemaps.com')
S3_REGION = os.environ.get('S3_REGION', 'eu')
# Database config
DB_HOST = os.environ.get('DB_HOST', 'localhost')
DB_PORT = int(os.environ.get('DB_PORT', '5432'))
DB_NAME = os.environ.get('DB_NAME', 'cavepediav2_db')
DB_USER = os.environ.get('DB_USER', 'cavepediav2_user')
DB_PASSWORD = os.environ['DB_PASSWORD']
s3 = boto3.client(
's3',
aws_access_key_id=S3_ACCESS_KEY,
aws_secret_access_key=S3_SECRET_KEY,
endpoint_url=S3_ENDPOINT,
region_name=S3_REGION,
)
co = cohere.ClientV2(api_key=COHERE_API_KEY)
conn = psycopg.connect(
host=DB_HOST,
port=DB_PORT,
dbname=DB_NAME,
user=DB_USER,
password=DB_PASSWORD,
row_factory=dict_row,
)
## init
# events table is created by minio up creation of event destination
def create_tables():
commands = (
"""
CREATE TABLE IF NOT EXISTS metadata (
id SERIAL PRIMARY KEY,
bucket TEXT,
key TEXT,
split BOOLEAN DEFAULT FALSE,
UNIQUE(bucket, key)
)
""",
"""
CREATE TABLE IF NOT EXISTS batches (
id SERIAL PRIMARY KEY,
platform TEXT,
batch_id TEXT,
type TEXT,
done BOOLEAN DEFAULT FALSE
)
""",
"CREATE EXTENSION IF NOT EXISTS vector",
"""
CREATE TABLE IF NOT EXISTS embeddings (
id SERIAL PRIMARY KEY,
role TEXT,
bucket TEXT,
key TEXT,
content TEXT,
embedding vector(1536),
UNIQUE(bucket, key)
)
""")
for command in commands:
conn.execute(command)
conn.commit()
register_vector(conn)
def import_files():
"""Scan import bucket for any new files; move them to the files bucket and add to db; delete from import bucket"""
BUCKET_IMPORT = 'cavepediav2-import'
BUCKET_FILES = 'cavepediav2-files'
# get new files; add to db, sync to main bucket; delete from import bucket
response = s3.list_objects_v2(Bucket=BUCKET_IMPORT)
if 'Contents' in response:
for obj in response['Contents']:
if obj['Key'].endswith('/'):
continue
s3.copy_object(
CopySource={'Bucket': BUCKET_IMPORT, 'Key': obj['Key']},
Bucket=BUCKET_FILES,
Key=obj['Key'],
)
conn.execute('INSERT INTO metadata (bucket, key) VALUES(%s, %s);', (BUCKET_FILES, obj['Key']))
conn.commit()
s3.delete_object(
Bucket=BUCKET_IMPORT,
Key=obj['Key'],
)
def split_files():
"""Split PDFs into single pages for easier processing"""
BUCKET_PAGES = 'cavepediav2-pages'
rows = conn.execute("SELECT COUNT(*) FROM metadata WHERE split = false")
row = rows.fetchone()
logger.info(f'Found {row["count"]} files to split.')
rows = conn.execute('SELECT * FROM metadata WHERE split = false')
for row in rows:
bucket = row['bucket']
key = row['key']
with conn.cursor() as cur:
logger.info(f'SPLITTING bucket: {bucket}, key: {key}')
##### get pdf #####
s3.download_file(bucket, key, '/tmp/file.pdf')
##### split #####
with open('/tmp/file.pdf', 'rb') as f:
reader = PdfReader(f)
for i in range(len(reader.pages)):
writer = PdfWriter()
writer.add_page(reader.pages[i])
with io.BytesIO() as bs:
writer.write(bs)
bs.seek(0)
s3.put_object(
Bucket=BUCKET_PAGES,
Key=f'{key}/page-{i + 1}.pdf',
Body=bs.getvalue()
)
cur.execute('INSERT INTO embeddings (bucket, key, role) VALUES (%s, %s, %s);', (BUCKET_PAGES, f'{key}/page-{i + 1}.pdf', key.split('/')[0]))
cur.execute('UPDATE metadata SET SPLIT = true WHERE id = %s', (row['id'],));
conn.commit()
def ocr_create_message(id, bucket, key):
"""Create message to send to claude"""
url = s3.generate_presigned_url(
'get_object',
Params={
'Bucket': bucket,
'Key': unquote(key)
},
)
message = {
'custom_id': f'doc-{id}',
'params': {
'model': 'claude-haiku-4-5',
'max_tokens': 4000,
'temperature': 1,
'messages': [
{
'role': 'user',
'content': [
{
'type': 'document',
'source': {
'type': 'url',
'url': url
}
},
{
'type': 'text',
'text': 'Extract all text from this document. Do not include any summary or conclusions of your own.'
}
]
}
]
}
}
return message
def ocr(bucket, key):
"""Gets OCR content of pdfs"""
url = s3.generate_presigned_url(
'get_object',
Params={
'Bucket': bucket,
'Key': unquote(key)
},
)
client = anthropic.Anthropic()
message = client.messages.create(
model='claude-haiku-4-5',
max_tokens=4000,
temperature=1,
messages=[
{
'role': 'user',
'content': [
{
'type': 'document',
'source': {
'type': 'url',
'url': url
}
},
{
'type': 'text',
'text': 'Extract all text from this document. Do not include any summary or conclusions of your own.'
}
]
}
],
)
return message
def claude_send_batch(batch):
"""Send a batch to claude"""
client = anthropic.Anthropic()
message_batch = client.messages.batches.create(
requests=batch
)
conn.execute('INSERT INTO batches (platform, batch_id, type) VALUES(%s, %s, %s);', ('claude', message_batch.id, 'ocr'))
conn.commit()
logger.info(f'Sent batch_id {message_batch.id} to claude')
def check_batches():
"""Check batch status"""
rows = conn.execute("SELECT COUNT(*) FROM batches WHERE done = false")
row = rows.fetchone()
logger.info(f'Found {row["count"]} batch(es) to process.')
rows = conn.execute("SELECT * FROM batches WHERE done = false")
client = anthropic.Anthropic()
for row in rows:
message_batch = client.messages.batches.retrieve(
row['batch_id'],
)
if message_batch.processing_status == 'ended':
results = client.messages.batches.results(
row['batch_id'],
)
with conn.cursor() as cur:
for result in results:
id = int(result.custom_id.split('-')[1])
try:
content = result.result.message.content[0].text
cur.execute('UPDATE embeddings SET content = %s WHERE id = %s;', (content, id))
except:
cur.execute('UPDATE embeddings SET content = %s WHERE id = %s;', ('ERROR', id))
cur.execute('UPDATE batches SET done = true WHERE batch_id = %s;', (row['batch_id'],))
conn.commit()
def ocr_main():
"""Checks for any non-OCR'd documents and sends them to claude in batches"""
## claude 4 sonnet ##
# tier 1 limit: 8k tokens/min
# tier 2: enough
# single pdf page: up to 2k tokens
# get docs where content is null
rows = conn.execute("SELECT COUNT(*) FROM embeddings WHERE content IS NULL LIMIT 1000")
row = rows.fetchone()
logger.info(f'Batching {row["count"]} documents to generate OCR content.')
rows = conn.execute("SELECT * FROM embeddings WHERE content IS NULL LIMIT 1000")
# batch docs; set content = WIP
batch = []
for row in rows:
id = row['id']
bucket = row['bucket']
key = row['key']
logger.info(f'Batching for OCR: {bucket}, key: {key}')
batch.append(ocr_create_message(id, bucket, key))
conn.execute('UPDATE embeddings SET content = %s WHERE id = %s;', ('WIP', id))
conn.commit()
if len(batch) > 0:
claude_send_batch(batch)
def embeddings_main():
"""Generate embeddings"""
rows = conn.execute("SELECT COUNT(*) FROM embeddings WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL")
row = rows.fetchone()
logger.info(f'Batching {row["count"]} documents to generate embeddings.')
rows = conn.execute("SELECT id, key, bucket, content FROM embeddings WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL")
for row in rows:
logger.info(f'Generating embeddings for id: {row["id"]}, bucket: {row["bucket"]}, key: {row["key"]}')
embedding = embed(row['content'], 'search_document')
conn.execute('UPDATE embeddings SET embedding = %s::vector WHERE id = %s;', (embedding, row['id']))
conn.commit()
### embeddings
def embed(text, input_type):
max_retries = 3
for attempt in range(max_retries):
try:
resp = co.embed(
texts=[text],
model='embed-v4.0',
input_type=input_type,
embedding_types=['float'],
output_dimension=1536,
)
return resp.embeddings.float[0]
except ApiError as e:
if e.status_code == 502 and attempt < max_retries - 1:
time.sleep(30 ** attempt) # exponential backoff
continue
raise Exception('cohere max retries exceeded')
def fix_pages():
i = 766
while i > 0:
conn.execute('UPDATE embeddings SET key = %s WHERE key = %s', (f'public/va/caves-of-virginia.pdf/page-{i}.pdf', f'public/va/caves-of-virginia.pdf/page-{i-1}.pdf'))
conn.commit()
i -= 1
if __name__ == '__main__':
create_tables()
while True:
import_files()
split_files()
check_batches()
ocr_main()
embeddings_main()
logger.info('sleeping 5 minutes')
time.sleep(5 * 60)