lint poller
Some checks failed
Build and Push Poller Docker Image / lint (push) Failing after 28s
Build and Push Poller Docker Image / build (push) Has been skipped

2025-12-12 18:43:36 +01:00
parent ae73ecf68b
commit d7dd7a98fe
6 changed files with 222 additions and 330 deletions


@@ -1,21 +1,18 @@
from cohere.core.api_error import ApiError
from pgvector.psycopg import register_vector, Bit
from psycopg.rows import dict_row
import io
import logging
import os
import time
from urllib.parse import unquote
from pypdf import PdfReader, PdfWriter
import anthropic
import boto3
import cohere
import dotenv
import datetime
import io
import json
import minio
import numpy as np
import os
import psycopg
import time
import logging
from cohere.core.api_error import ApiError
from pgvector.psycopg import register_vector
from psycopg.rows import dict_row
from pypdf import PdfReader, PdfWriter
from pythonjsonlogger.json import JsonFormatter
logger = logging.getLogger(__name__)
@@ -31,21 +28,21 @@ logger.addHandler(logHandler)
dotenv.load_dotenv()
# Required environment variables
COHERE_API_KEY = os.environ['COHERE_API_KEY']
S3_ACCESS_KEY = os.environ['S3_ACCESS_KEY']
S3_SECRET_KEY = os.environ['S3_SECRET_KEY']
S3_ENDPOINT = os.environ.get('S3_ENDPOINT', 'https://s3.bigcavemaps.com')
S3_REGION = os.environ.get('S3_REGION', 'eu')
COHERE_API_KEY = os.environ["COHERE_API_KEY"]
S3_ACCESS_KEY = os.environ["S3_ACCESS_KEY"]
S3_SECRET_KEY = os.environ["S3_SECRET_KEY"]
S3_ENDPOINT = os.environ.get("S3_ENDPOINT", "https://s3.bigcavemaps.com")
S3_REGION = os.environ.get("S3_REGION", "eu")
# Database config
DB_HOST = os.environ.get('DB_HOST', 'localhost')
DB_PORT = int(os.environ.get('DB_PORT', '5432'))
DB_NAME = os.environ.get('DB_NAME', 'cavepediav2_db')
DB_USER = os.environ.get('DB_USER', 'cavepediav2_user')
DB_PASSWORD = os.environ['DB_PASSWORD']
DB_HOST = os.environ.get("DB_HOST", "localhost")
DB_PORT = int(os.environ.get("DB_PORT", "5432"))
DB_NAME = os.environ.get("DB_NAME", "cavepediav2_db")
DB_USER = os.environ.get("DB_USER", "cavepediav2_user")
DB_PASSWORD = os.environ["DB_PASSWORD"]
s3 = boto3.client(
's3',
"s3",
aws_access_key_id=S3_ACCESS_KEY,
aws_secret_access_key=S3_SECRET_KEY,
endpoint_url=S3_ENDPOINT,
@@ -61,6 +58,7 @@ conn = psycopg.connect(
row_factory=dict_row,
)
## init
# events table is created by minio upon creation of the event destination
def create_tables():
@@ -94,54 +92,58 @@ def create_tables():
embedding vector(1536),
UNIQUE(bucket, key)
)
""")
""",
)
for command in commands:
conn.execute(command)
conn.commit()
register_vector(conn)
def import_files():
"""Scan import bucket for any new files; move them to the files bucket and add to db; delete from import bucket"""
BUCKET_IMPORT = 'cavepediav2-import'
BUCKET_FILES = 'cavepediav2-files'
BUCKET_IMPORT = "cavepediav2-import"
BUCKET_FILES = "cavepediav2-files"
# get new files; add to db, sync to main bucket; delete from import bucket
response = s3.list_objects_v2(Bucket=BUCKET_IMPORT)
if 'Contents' in response:
for obj in response['Contents']:
if obj['Key'].endswith('/'):
if "Contents" in response:
for obj in response["Contents"]:
if obj["Key"].endswith("/"):
continue
s3.copy_object(
CopySource={'Bucket': BUCKET_IMPORT, 'Key': obj['Key']},
CopySource={"Bucket": BUCKET_IMPORT, "Key": obj["Key"]},
Bucket=BUCKET_FILES,
Key=obj['Key'],
Key=obj["Key"],
)
conn.execute('INSERT INTO metadata (bucket, key) VALUES(%s, %s);', (BUCKET_FILES, obj['Key']))
conn.execute("INSERT INTO metadata (bucket, key) VALUES(%s, %s);", (BUCKET_FILES, obj["Key"]))
conn.commit()
s3.delete_object(
Bucket=BUCKET_IMPORT,
Key=obj['Key'],
Key=obj["Key"],
)
def split_files():
"""Split PDFs into single pages for easier processing"""
BUCKET_PAGES = 'cavepediav2-pages'
BUCKET_PAGES = "cavepediav2-pages"
rows = conn.execute("SELECT COUNT(*) FROM metadata WHERE split = false")
row = rows.fetchone()
logger.info(f'Found {row["count"]} files to split.')
rows = conn.execute('SELECT * FROM metadata WHERE split = false')
assert row is not None
logger.info(f"Found {row['count']} files to split.")
rows = conn.execute("SELECT * FROM metadata WHERE split = false")
for row in rows:
bucket = row['bucket']
key = row['key']
bucket = row["bucket"]
key = row["key"]
with conn.cursor() as cur:
logger.info(f'SPLITTING bucket: {bucket}, key: {key}')
logger.info(f"SPLITTING bucket: {bucket}, key: {key}")
##### get pdf #####
s3.download_file(bucket, key, '/tmp/file.pdf')
s3.download_file(bucket, key, "/tmp/file.pdf")
##### split #####
with open('/tmp/file.pdf', 'rb') as f:
with open("/tmp/file.pdf", "rb") as f:
reader = PdfReader(f)
for i in range(len(reader.pages)):
@@ -151,129 +153,120 @@ def split_files():
with io.BytesIO() as bs:
writer.write(bs)
bs.seek(0)
s3.put_object(
Bucket=BUCKET_PAGES,
Key=f'{key}/page-{i + 1}.pdf',
Body=bs.getvalue()
)
cur.execute('INSERT INTO embeddings (bucket, key, role) VALUES (%s, %s, %s);', (BUCKET_PAGES, f'{key}/page-{i + 1}.pdf', key.split('/')[0]))
cur.execute('UPDATE metadata SET SPLIT = true WHERE id = %s', (row['id'],));
s3.put_object(Bucket=BUCKET_PAGES, Key=f"{key}/page-{i + 1}.pdf", Body=bs.getvalue())
page_key = f"{key}/page-{i + 1}.pdf"
role = key.split("/")[0]
cur.execute(
"INSERT INTO embeddings (bucket, key, role) VALUES (%s, %s, %s);",
(BUCKET_PAGES, page_key, role),
)
cur.execute("UPDATE metadata SET SPLIT = true WHERE id = %s", (row["id"],))
conn.commit()
def ocr_create_message(id, bucket, key):
"""Create message to send to claude"""
url = s3.generate_presigned_url(
'get_object',
Params={
'Bucket': bucket,
'Key': unquote(key)
},
"get_object",
Params={"Bucket": bucket, "Key": unquote(key)},
)
message = {
'custom_id': f'doc-{id}',
'params': {
'model': 'claude-haiku-4-5',
'max_tokens': 4000,
'temperature': 1,
'messages': [
"custom_id": f"doc-{id}",
"params": {
"model": "claude-haiku-4-5",
"max_tokens": 4000,
"temperature": 1,
"messages": [
{
'role': 'user',
'content': [
"role": "user",
"content": [
{"type": "document", "source": {"type": "url", "url": url}},
{
'type': 'document',
'source': {
'type': 'url',
'url': url
}
"type": "text",
"text": "Extract all text from this document. "
"Do not include any summary or conclusions of your own.",
},
{
'type': 'text',
'text': 'Extract all text from this document. Do not include any summary or conclusions of your own.'
}
]
],
}
]
}
],
},
}
return message
def ocr(bucket, key):
"""Gets OCR content of pdfs"""
url = s3.generate_presigned_url(
'get_object',
Params={
'Bucket': bucket,
'Key': unquote(key)
},
"get_object",
Params={"Bucket": bucket, "Key": unquote(key)},
)
client = anthropic.Anthropic()
message = client.messages.create(
model='claude-haiku-4-5',
model="claude-haiku-4-5",
max_tokens=4000,
temperature=1,
messages=[
{
'role': 'user',
'content': [
"role": "user",
"content": [
{"type": "document", "source": {"type": "url", "url": url}},
{
'type': 'document',
'source': {
'type': 'url',
'url': url
}
"type": "text",
"text": "Extract all text from this document. "
"Do not include any summary or conclusions of your own.",
},
{
'type': 'text',
'text': 'Extract all text from this document. Do not include any summary or conclusions of your own.'
}
]
],
}
],
)
return message
def claude_send_batch(batch):
"""Send a batch to claude"""
client = anthropic.Anthropic()
message_batch = client.messages.batches.create(
requests=batch
)
message_batch = client.messages.batches.create(requests=batch)
conn.execute('INSERT INTO batches (platform, batch_id, type) VALUES(%s, %s, %s);', ('claude', message_batch.id, 'ocr'))
conn.execute(
"INSERT INTO batches (platform, batch_id, type) VALUES(%s, %s, %s);", ("claude", message_batch.id, "ocr")
)
conn.commit()
logger.info(f'Sent batch_id {message_batch.id} to claude')
logger.info(f"Sent batch_id {message_batch.id} to claude")
def check_batches():
"""Check batch status"""
rows = conn.execute("SELECT COUNT(*) FROM batches WHERE done = false")
row = rows.fetchone()
logger.info(f'Found {row["count"]} batch(es) to process.')
assert row is not None
logger.info(f"Found {row['count']} batch(es) to process.")
rows = conn.execute("SELECT * FROM batches WHERE done = false")
client = anthropic.Anthropic()
for row in rows:
message_batch = client.messages.batches.retrieve(
row['batch_id'],
row["batch_id"],
)
if message_batch.processing_status == 'ended':
if message_batch.processing_status == "ended":
results = client.messages.batches.results(
row['batch_id'],
row["batch_id"],
)
with conn.cursor() as cur:
for result in results:
id = int(result.custom_id.split('-')[1])
id = int(result.custom_id.split("-")[1])
try:
content = result.result.message.content[0].text
cur.execute('UPDATE embeddings SET content = %s WHERE id = %s;', (content, id))
except:
cur.execute('UPDATE embeddings SET content = %s WHERE id = %s;', ('ERROR', id))
cur.execute('UPDATE batches SET done = true WHERE batch_id = %s;', (row['batch_id'],))
content = result.result.message.content[0].text # type: ignore[union-attr]
cur.execute("UPDATE embeddings SET content = %s WHERE id = %s;", (content, id))
except Exception:
cur.execute("UPDATE embeddings SET content = %s WHERE id = %s;", ("ERROR", id))
cur.execute("UPDATE batches SET done = true WHERE batch_id = %s;", (row["batch_id"],))
conn.commit()
def ocr_main():
"""Checks for any non-OCR'd documents and sends them to claude in batches"""
## claude 4 sonnet ##
@@ -284,37 +277,49 @@ def ocr_main():
# get docs where content is null
rows = conn.execute("SELECT COUNT(*) FROM embeddings WHERE content IS NULL LIMIT 1000")
row = rows.fetchone()
logger.info(f'Batching {row["count"]} documents to generate OCR content.')
assert row is not None
logger.info(f"Batching {row['count']} documents to generate OCR content.")
rows = conn.execute("SELECT * FROM embeddings WHERE content IS NULL LIMIT 1000")
# batch docs; set content = WIP
batch = []
for row in rows:
id = row['id']
bucket = row['bucket']
key = row['key']
id = row["id"]
bucket = row["bucket"]
key = row["key"]
logger.info(f'Batching for OCR: {bucket}, key: {key}')
logger.info(f"Batching for OCR: {bucket}, key: {key}")
batch.append(ocr_create_message(id, bucket, key))
conn.execute('UPDATE embeddings SET content = %s WHERE id = %s;', ('WIP', id))
conn.execute("UPDATE embeddings SET content = %s WHERE id = %s;", ("WIP", id))
conn.commit()
if len(batch) > 0:
claude_send_batch(batch)
def embeddings_main():
"""Generate embeddings"""
rows = conn.execute("SELECT COUNT(*) FROM embeddings WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL")
count_query = """
SELECT COUNT(*) FROM embeddings
WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL
"""
rows = conn.execute(count_query)
row = rows.fetchone()
logger.info(f'Batching {row["count"]} documents to generate embeddings.')
rows = conn.execute("SELECT id, key, bucket, content FROM embeddings WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL")
assert row is not None
logger.info(f"Batching {row['count']} documents to generate embeddings.")
select_query = """
SELECT id, key, bucket, content FROM embeddings
WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL
"""
rows = conn.execute(select_query)
for row in rows:
logger.info(f'Generating embeddings for id: {row["id"]}, bucket: {row["bucket"]}, key: {row["key"]}')
embedding = embed(row['content'], 'search_document')
conn.execute('UPDATE embeddings SET embedding = %s::vector WHERE id = %s;', (embedding, row['id']))
logger.info(f"Generating embeddings for id: {row['id']}, bucket: {row['bucket']}, key: {row['key']}")
embedding = embed(row["content"], "search_document")
conn.execute("UPDATE embeddings SET embedding = %s::vector WHERE id = %s;", (embedding, row["id"]))
conn.commit()
### embeddings
def embed(text, input_type):
max_retries = 3
@@ -322,27 +327,31 @@ def embed(text, input_type):
try:
resp = co.embed(
texts=[text],
model='embed-v4.0',
model="embed-v4.0",
input_type=input_type,
embedding_types=['float'],
embedding_types=["float"],
output_dimension=1536,
)
return resp.embeddings.float[0]
assert resp.embeddings.float_ is not None
return resp.embeddings.float_[0]
except ApiError as e:
if e.status_code == 502 and attempt < max_retries - 1:
time.sleep(30 ** attempt) # exponential backoff
time.sleep(30**attempt) # exponential backoff
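# Worked numbers for the backoff above (assuming attempt counts 0..max_retries-1):
# attempt 0 sleeps 30**0 = 1 s, attempt 1 sleeps 30**1 = 30 s, and with
# max_retries = 3 a third 502 falls through to the raise below without another retry.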
continue
raise Exception('cohere max retries exceeded')
raise Exception("cohere max retries exceeded")
def fix_pages():
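# Descriptive note (inferred from the code, not stated in the commit): this one-off
# fix shifts each page key for caves-of-virginia.pdf up by one, renaming
# page-{i-1} to page-{i} from the last page downward so no rename clobbers a key
# that still has to be moved.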
i = 766
while i > 0:
conn.execute('UPDATE embeddings SET key = %s WHERE key = %s', (f'public/va/caves-of-virginia.pdf/page-{i}.pdf', f'public/va/caves-of-virginia.pdf/page-{i-1}.pdf'))
new_key = f"public/va/caves-of-virginia.pdf/page-{i}.pdf"
old_key = f"public/va/caves-of-virginia.pdf/page-{i - 1}.pdf"
conn.execute("UPDATE embeddings SET key = %s WHERE key = %s", (new_key, old_key))
conn.commit()
i -= 1
if __name__ == '__main__':
if __name__ == "__main__":
create_tables()
while True:
@@ -352,5 +361,5 @@ if __name__ == '__main__':
ocr_main()
embeddings_main()
logger.info('sleeping 5 minutes')
logger.info("sleeping 5 minutes")
time.sleep(5 * 60)
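
The loop above leaves the embeddings table holding 1536-dimension embed-v4.0 vectors stored with input_type "search_document". For context, a minimal query-side sketch of how those rows could be searched follows; it reuses the poller's schema and environment variables, while the search_pages() helper, the cosine-distance (<=>) ordering, and the example query string are illustrative assumptions rather than part of this commit.

import os

import cohere
import dotenv
import psycopg
from pgvector.psycopg import register_vector
from psycopg.rows import dict_row

dotenv.load_dotenv()
co = cohere.ClientV2(api_key=os.environ["COHERE_API_KEY"])


def search_pages(conn, query, top_k=5):
    """Embed the query text and return the closest pages by cosine distance."""
    resp = co.embed(
        texts=[query],
        model="embed-v4.0",
        input_type="search_query",  # query-side counterpart of "search_document"
        embedding_types=["float"],
        output_dimension=1536,
    )
    assert resp.embeddings.float_ is not None
    query_embedding = resp.embeddings.float_[0]
    rows = conn.execute(
        "SELECT bucket, key, content, embedding <=> %s::vector AS distance "
        "FROM embeddings WHERE embedding IS NOT NULL ORDER BY distance LIMIT %s;",
        (query_embedding, top_k),
    )
    return rows.fetchall()


if __name__ == "__main__":
    conn = psycopg.connect(
        host=os.environ.get("DB_HOST", "localhost"),
        port=int(os.environ.get("DB_PORT", "5432")),
        dbname=os.environ.get("DB_NAME", "cavepediav2_db"),
        user=os.environ.get("DB_USER", "cavepediav2_user"),
        password=os.environ["DB_PASSWORD"],
        row_factory=dict_row,
    )
    register_vector(conn)
    for hit in search_pages(conn, "vertical pit caves in Virginia"):  # example query
        print(hit["key"], hit["distance"])

Ordering by <=> (cosine distance) is one reasonable choice here; adding an HNSW or IVFFlat index on the embedding column would be the usual way to keep this query fast as the table grows.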