# cavepediav2/poller/main.py
import io
import logging
import os
import time
from urllib.parse import unquote
import anthropic
import boto3
import cohere
import dotenv
import psycopg
from cohere.core.api_error import ApiError
from pgvector.psycopg import register_vector
from psycopg.rows import dict_row
from pypdf import PdfReader, PdfWriter
from pythonjsonlogger.json import JsonFormatter
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logHandler = logging.StreamHandler()
formatter = JsonFormatter("{asctime}{message}", style="{")
logHandler.setFormatter(formatter)
logger.addHandler(logHandler)
#####
# Load .env file if it exists (for local dev)
dotenv.load_dotenv()
# Required environment variables
COHERE_API_KEY = os.environ["COHERE_API_KEY"]
S3_ACCESS_KEY = os.environ["S3_ACCESS_KEY"]
S3_SECRET_KEY = os.environ["S3_SECRET_KEY"]
S3_ENDPOINT = os.environ.get("S3_ENDPOINT", "https://s3.bigcavemaps.com")
S3_REGION = os.environ.get("S3_REGION", "eu")
# Database config
DB_HOST = os.environ.get("DB_HOST", "localhost")
DB_PORT = int(os.environ.get("DB_PORT", "5432"))
DB_NAME = os.environ.get("DB_NAME", "cavepediav2_db")
DB_USER = os.environ.get("DB_USER", "cavepediav2_user")
DB_PASSWORD = os.environ["DB_PASSWORD"]
s3 = boto3.client(
"s3",
aws_access_key_id=S3_ACCESS_KEY,
aws_secret_access_key=S3_SECRET_KEY,
endpoint_url=S3_ENDPOINT,
region_name=S3_REGION,
)
co = cohere.ClientV2(api_key=COHERE_API_KEY)
conn = psycopg.connect(
host=DB_HOST,
port=DB_PORT,
dbname=DB_NAME,
user=DB_USER,
password=DB_PASSWORD,
row_factory=dict_row,
)
## init
# The events table is created by MinIO upon creation of the event destination
def create_tables():
commands = (
"""
CREATE TABLE IF NOT EXISTS metadata (
id SERIAL PRIMARY KEY,
bucket TEXT,
key TEXT,
split BOOLEAN DEFAULT FALSE,
UNIQUE(bucket, key)
)
""",
"""
CREATE TABLE IF NOT EXISTS batches (
id SERIAL PRIMARY KEY,
platform TEXT,
batch_id TEXT,
type TEXT,
done BOOLEAN DEFAULT FALSE
)
""",
"CREATE EXTENSION IF NOT EXISTS vector",
"""
CREATE TABLE IF NOT EXISTS embeddings (
id SERIAL PRIMARY KEY,
role TEXT,
bucket TEXT,
key TEXT,
content TEXT,
embedding vector(1536),
UNIQUE(bucket, key)
)
""",
)
for command in commands:
conn.execute(command)
conn.commit()
register_vector(conn)
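# Schema notes: `embeddings.embedding` is vector(1536), matching the output_dimension
# requested from Cohere's embed-v4.0 below; `role` is the top-level prefix of the
# object key and is filled in by split_files().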
def import_files():
"""Scan import bucket for any new files; move them to the files bucket and add to db; delete from import bucket"""
BUCKET_IMPORT = "cavepediav2-import"
BUCKET_FILES = "cavepediav2-files"
# get new files; add to db, sync to main bucket; delete from import bucket
response = s3.list_objects_v2(Bucket=BUCKET_IMPORT)
if "Contents" in response:
for obj in response["Contents"]:
if obj["Key"].endswith("/"):
continue
s3.copy_object(
CopySource={"Bucket": BUCKET_IMPORT, "Key": obj["Key"]},
Bucket=BUCKET_FILES,
Key=obj["Key"],
)
conn.execute("INSERT INTO metadata (bucket, key) VALUES(%s, %s);", (BUCKET_FILES, obj["Key"]))
conn.commit()
s3.delete_object(
Bucket=BUCKET_IMPORT,
Key=obj["Key"],
)
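# Note: list_objects_v2 returns at most 1000 keys per call; because imported objects
# are deleted from the import bucket after copying, any remainder is picked up on a
# later polling iteration.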
def split_files():
"""Split PDFs into single pages for easier processing"""
BUCKET_PAGES = "cavepediav2-pages"
rows = conn.execute("SELECT COUNT(*) FROM metadata WHERE split = false")
row = rows.fetchone()
assert row is not None
logger.info(f"Found {row['count']} files to split.")
rows = conn.execute("SELECT * FROM metadata WHERE split = false")
for row in rows:
bucket = row["bucket"]
key = row["key"]
with conn.cursor() as cur:
logger.info(f"SPLITTING bucket: {bucket}, key: {key}")
##### get pdf #####
s3.download_file(bucket, key, "/tmp/file.pdf")
##### split #####
with open("/tmp/file.pdf", "rb") as f:
reader = PdfReader(f)
for i in range(len(reader.pages)):
writer = PdfWriter()
writer.add_page(reader.pages[i])
with io.BytesIO() as bs:
writer.write(bs)
bs.seek(0)
page_key = f"{key}/page-{i + 1}.pdf"
s3.put_object(Bucket=BUCKET_PAGES, Key=page_key, Body=bs.getvalue())
role = key.split("/")[0]
cur.execute(
"INSERT INTO embeddings (bucket, key, role) VALUES (%s, %s, %s);",
(BUCKET_PAGES, page_key, role),
)
cur.execute("UPDATE metadata SET SPLIT = true WHERE id = %s", (row["id"],))
conn.commit()
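# Each page becomes its own object in the pages bucket and its own row in `embeddings`,
# so OCR and embedding happen per page, not per document.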
def ocr_create_message(id, bucket, key):
"""Create message to send to claude"""
url = s3.generate_presigned_url(
"get_object",
Params={"Bucket": bucket, "Key": unquote(key)},
)
message = {
"custom_id": f"doc-{id}",
"params": {
"model": "claude-haiku-4-5",
"max_tokens": 4000,
"temperature": 1,
"messages": [
{
"role": "user",
"content": [
{"type": "document", "source": {"type": "url", "url": url}},
{
"type": "text",
"text": "Extract all text from this document. "
"Do not include any summary or conclusions of your own.",
},
],
}
],
},
}
return message
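# The custom_id "doc-{id}" encodes the embeddings row id so check_batches() can map each
# batch result back to its row. The presigned URL lets the API fetch the page directly
# from S3; note that presigned URLs expire (boto3 defaults to one hour), so they must
# remain fetchable until the batch is processed.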
def ocr(bucket, key):
"""Gets OCR content of pdfs"""
url = s3.generate_presigned_url(
"get_object",
Params={"Bucket": bucket, "Key": unquote(key)},
)
client = anthropic.Anthropic()
message = client.messages.create(
model="claude-haiku-4-5",
max_tokens=4000,
temperature=1,
messages=[
{
"role": "user",
"content": [
{"type": "document", "source": {"type": "url", "url": url}},
{
"type": "text",
"text": "Extract all text from this document. "
"Do not include any summary or conclusions of your own.",
},
],
}
],
)
return message
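# Synchronous single-document variant of the batch flow above; it is not called from the
# main loop and appears to be kept for ad-hoc use.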
def claude_send_batch(batch):
"""Send a batch to claude"""
client = anthropic.Anthropic()
message_batch = client.messages.batches.create(requests=batch)
conn.execute(
"INSERT INTO batches (platform, batch_id, type) VALUES(%s, %s, %s);", ("claude", message_batch.id, "ocr")
)
conn.commit()
logger.info(f"Sent batch_id {message_batch.id} to claude")
def check_batches():
"""Check batch status"""
rows = conn.execute("SELECT COUNT(*) FROM batches WHERE done = false")
row = rows.fetchone()
assert row is not None
logger.info(f"Found {row['count']} batch(es) to process.")
rows = conn.execute("SELECT * FROM batches WHERE done = false")
client = anthropic.Anthropic()
for row in rows:
message_batch = client.messages.batches.retrieve(
row["batch_id"],
)
if message_batch.processing_status == "ended":
results = client.messages.batches.results(
row["batch_id"],
)
with conn.cursor() as cur:
for result in results:
id = int(result.custom_id.split("-")[1])
try:
content = result.result.message.content[0].text # type: ignore[union-attr]
cur.execute("UPDATE embeddings SET content = %s WHERE id = %s;", (content, id))
except Exception:
cur.execute("UPDATE embeddings SET content = %s WHERE id = %s;", ("ERROR", id))
cur.execute("UPDATE batches SET done = true WHERE batch_id = %s;", (row["batch_id"],))
conn.commit()
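# Results that cannot be parsed (e.g. errored requests) are stored with content = 'ERROR';
# embeddings_main() skips those rows as well as 'WIP' ones.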
def ocr_main():
"""Checks for any non-OCR'd documents and sends them to claude in batches"""
## rate-limit sizing notes (written against claude 4 sonnet; the batch below uses claude-haiku-4-5) ##
# tier 1 limit: 8k tokens/min
# tier 2: enough
# single pdf page: up to 2k tokens
# get docs where content is null; the LIMIT below caps how many are batched per run
rows = conn.execute("SELECT COUNT(*) FROM embeddings WHERE content IS NULL")
row = rows.fetchone()
assert row is not None
logger.info(f"Found {row['count']} documents without OCR content; batching up to 1000 this run.")
rows = conn.execute("SELECT * FROM embeddings WHERE content IS NULL LIMIT 1000")
# batch docs; set content = WIP
batch = []
for row in rows:
id = row["id"]
bucket = row["bucket"]
key = row["key"]
logger.info(f"Batching for OCR: {bucket}, key: {key}")
batch.append(ocr_create_message(id, bucket, key))
conn.execute("UPDATE embeddings SET content = %s WHERE id = %s;", ("WIP", id))
conn.commit()
if len(batch) > 0:
claude_send_batch(batch)
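# Rows are marked content = 'WIP' as soon as they are batched so the next polling
# iteration does not queue them again.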
def embeddings_main():
"""Generate embeddings"""
count_query = """
SELECT COUNT(*) FROM embeddings
WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL
"""
rows = conn.execute(count_query)
row = rows.fetchone()
assert row is not None
logger.info(f"Batching {row['count']} documents to generate embeddings.")
select_query = """
SELECT id, key, bucket, content FROM embeddings
WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL
"""
rows = conn.execute(select_query)
for row in rows:
logger.info(f"Generating embeddings for id: {row['id']}, bucket: {row['bucket']}, key: {row['key']}")
embedding = embed(row["content"], "search_document")
conn.execute("UPDATE embeddings SET embedding = %s::vector WHERE id = %s;", (embedding, row["id"]))
conn.commit()
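# Embeddings are written back with an explicit %s::vector cast; register_vector(conn)
# (called in create_tables) installs the pgvector adapters on this connection.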
### embeddings
def embed(text, input_type):
max_retries = 3
for attempt in range(max_retries):
try:
resp = co.embed(
texts=[text],
model="embed-v4.0",
input_type=input_type,
embedding_types=["float"],
output_dimension=1536,
)
assert resp.embeddings.float_ is not None
return resp.embeddings.float_[0]
except ApiError as e:
if e.status_code == 502 and attempt < max_retries - 1:
time.sleep(30**attempt)  # exponential backoff: 1s, then 30s
continue
raise Exception("cohere max retries exceeded") from e
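# Cohere's embed endpoint distinguishes input types: "search_document" is used when
# indexing pages, "search_query" is the counterpart for query text. As a sketch (not
# part of the poller), a hypothetical consumer could rank pages by cosine distance
# with pgvector's `<=>` operator:
#
#   query_vec = embed("caves near Burnsville Cove", "search_query")
#   hits = conn.execute(
#       "SELECT bucket, key, content FROM embeddings "
#       "WHERE embedding IS NOT NULL "
#       "ORDER BY embedding <=> %s::vector LIMIT 5;",
#       (query_vec,),
#   ).fetchall()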
def fix_pages():
i = 766
while i > 0:
new_key = f"public/va/caves-of-virginia.pdf/page-{i}.pdf"
old_key = f"public/va/caves-of-virginia.pdf/page-{i - 1}.pdf"
conn.execute("UPDATE embeddings SET key = %s WHERE key = %s", (new_key, old_key))
conn.commit()
i -= 1
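# One-off maintenance helper that shifts caves-of-virginia.pdf page keys up by one
# (renaming page-{i-1} to page-{i}, highest first); it is not invoked by the main loop below.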
if __name__ == "__main__":
create_tables()
while True:
import_files()
split_files()
check_batches()
ocr_main()
embeddings_main()
logger.info("sleeping 5 minutes")
time.sleep(5 * 60)