add mcp, batching

2025-12-07 04:35:21 +01:00
parent 30f68a9d04
commit d6bc34d138
14 changed files with 1973 additions and 138 deletions


@@ -1,8 +1,10 @@
+from cohere.core.api_error import ApiError
from pgvector.psycopg import register_vector, Bit
from psycopg.rows import dict_row
from urllib.parse import unquote
from pypdf import PdfReader, PdfWriter
import anthropic
+import boto3
import cohere
import dotenv
import datetime
@@ -25,22 +27,23 @@ logger.addHandler(logHandler)
#####
dotenv.load_dotenv('/home/paul/scripts-private/lech/cavepedia-v2/poller.env')
dotenv.load_dotenv('/home/pew/scripts-private/loser/cavepedia-v2/poller.env')
COHERE_API_KEY = os.getenv('COHERE_API_KEY')
-MINIO_ACCESS_KEY = os.getenv('MINIO_ACCESS_KEY')
-MINIO_SECRET_KEY = os.getenv('MINIO_SECRET_KEY')
+S3_ACCESS_KEY = os.getenv('S3_ACCESS_KEY')
+S3_SECRET_KEY = os.getenv('S3_SECRET_KEY')
-s3 = minio.Minio(
-'s3.bigcavemaps.com',
-access_key=MINIO_ACCESS_KEY,
-secret_key=MINIO_SECRET_KEY,
-region='kansascity',
+s3 = boto3.client(
+'s3',
+aws_access_key_id=S3_ACCESS_KEY,
+aws_secret_access_key=S3_SECRET_KEY,
+endpoint_url='https://s3.bigcavemaps.com',
+region_name='eu',
)
-co = cohere.ClientV2(COHERE_API_KEY)
+co = cohere.ClientV2(api_key=COHERE_API_KEY)
conn = psycopg.connect(
-host='127.0.0.1',
-port=4010,
+host='::1',
+port=9030,
dbname='cavepediav2_db',
user='cavepediav2_user',
password='cavepediav2_pw',
@@ -51,14 +54,33 @@ conn = psycopg.connect(
# events table is created by minio upon creation of the event destination
def create_tables():
commands = (
"""
CREATE TABLE IF NOT EXISTS metadata (
id SERIAL PRIMARY KEY,
bucket TEXT,
key TEXT,
split BOOLEAN DEFAULT FALSE,
UNIQUE(bucket, key)
)
""",
"""
CREATE TABLE IF NOT EXISTS batches (
id SERIAL PRIMARY KEY,
platform TEXT,
batch_id TEXT,
type TEXT,
done BOOLEAN DEFAULT FALSE
)
""",
"CREATE EXTENSION IF NOT EXISTS vector",
"""
CREATE TABLE IF NOT EXISTS embeddings (
id SERIAL PRIMARY KEY,
bucket TEXT,
key TEXT,
content TEXT,
embedding vector(1536),
-PRIMARY KEY (bucket, key)
+UNIQUE(bucket, key)
)
""")
for command in commands:
@@ -66,53 +88,118 @@ def create_tables():
conn.commit()
register_vector(conn)
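# A possible follow-up, sketched here as an assumption rather than part of the
# code above: once the embeddings table grows, a pgvector HNSW index speeds up
# similarity search on vector(1536); vector_cosine_ops matches the <=> operator.
# conn.execute('CREATE INDEX IF NOT EXISTS embeddings_hnsw ON embeddings USING hnsw (embedding vector_cosine_ops)')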
## splitting
-def split_pdfs():
-rows = conn.execute('SELECT * FROM events')
+def import_files():
+"""Scan import bucket for any new files; move them to the files bucket and add to db; delete from import bucket"""
+BUCKET_IMPORT = 'cavepediav2-import'
+BUCKET_FILES = 'cavepediav2-files'
+# get new files; add to db, sync to main bucket; delete from import bucket
+response = s3.list_objects_v2(Bucket=BUCKET_IMPORT)
+if 'Contents' in response:
+for obj in response['Contents']:
+if obj['Key'].endswith('/'):
+continue
+s3.copy_object(
+CopySource={'Bucket': BUCKET_IMPORT, 'Key': obj['Key']},
+Bucket=BUCKET_FILES,
+Key=obj['Key'],
+)
+conn.execute('INSERT INTO metadata (bucket, key) VALUES(%s, %s);', (BUCKET_FILES, obj['Key']))
+conn.commit()
+s3.delete_object(
+Bucket=BUCKET_IMPORT,
+Key=obj['Key'],
+)
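# Note: list_objects_v2 returns at most 1000 keys per call, so a large backlog
# is only drained over several polls. A paginator would import everything in
# one pass; a minimal sketch of the same flow, assuming the buckets above:
# paginator = s3.get_paginator('list_objects_v2')
# for page in paginator.paginate(Bucket=BUCKET_IMPORT):
#     for obj in page.get('Contents', []):
#         ...  # copy, insert, delete as above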
+def split_files():
+"""Split PDFs into single pages for easier processing"""
+BUCKET_PAGES = 'cavepediav2-pages'
+rows = conn.execute("SELECT COUNT(*) FROM metadata WHERE split = false")
+row = rows.fetchone()
+logger.info(f'Found {row["count"]} files to split.')
+rows = conn.execute('SELECT * FROM metadata WHERE split = false')
for row in rows:
+bucket = row['bucket']
+key = row['key']
with conn.cursor() as cur:
-for record in row['value']['Records']:
-bucket = record['s3']['bucket']['name']
-key = record['s3']['object']['key']
-key = unquote(key)
-key = key.replace('+',' ')
-logger.info(f'SPLITTING bucket: {bucket}, key: {key}')
+logger.info(f'SPLITTING bucket: {bucket}, key: {key}')
-##### get pdf #####
-with s3.get_object(bucket, key) as obj:
-with open('/tmp/file.pdf', 'wb') as f:
-while True:
-chunk = obj.read(1024)
-if not chunk:
-break
-f.write(chunk)
+##### get pdf #####
+s3.download_file(bucket, key, '/tmp/file.pdf')
-##### split #####
-with open('/tmp/file.pdf', 'rb') as f:
-reader = PdfReader(f)
+##### split #####
+with open('/tmp/file.pdf', 'rb') as f:
+reader = PdfReader(f)
-for i in range(len(reader.pages)):
-writer = PdfWriter()
-writer.add_page(reader.pages[i])
+for i in range(len(reader.pages)):
+writer = PdfWriter()
+writer.add_page(reader.pages[i])
-with io.BytesIO() as bs:
-writer.write(bs)
-bs.seek(0)
-s3.put_object(f'{bucket}-pages', f'{key}/page-{i + 1}.pdf', bs, len(bs.getvalue()))
-cur.execute('INSERT INTO embeddings (bucket, key) VALUES (%s, %s);', (f'{bucket}-pages', f'{key}/page-{i + 1}.pdf'))
-cur.execute('DELETE FROM events WHERE key = %s', (row['key'],))
+with io.BytesIO() as bs:
+writer.write(bs)
+bs.seek(0)
+s3.put_object(
+Bucket=BUCKET_PAGES,
+Key=f'{key}/page-{i + 1}.pdf',
+Body=bs.getvalue()
+)
+cur.execute('INSERT INTO embeddings (bucket, key) VALUES (%s, %s);', (BUCKET_PAGES, f'{key}/page-{i + 1}.pdf'))
+cur.execute('UPDATE metadata SET split = true WHERE id = %s', (row['id'],))
conn.commit()
## processing
+def ocr_create_message(id, bucket, key):
+"""Create message to send to claude"""
+url = s3.generate_presigned_url(
+'get_object',
+Params={
+'Bucket': bucket,
+'Key': unquote(key)
+},
+)
+message = {
+'custom_id': f'doc-{id}',
+'params': {
+'model': 'claude-haiku-4-5',
+'max_tokens': 4000,
+'temperature': 1,
+'messages': [
+{
+'role': 'user',
+'content': [
+{
+'type': 'document',
+'source': {
+'type': 'url',
+'url': url
+}
+},
+{
+'type': 'text',
+'text': 'Extract all text from this document. Do not include any summary or conclusions of your own.'
+}
+]
+}
+]
+}
+}
+return message
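# Caveat: generate_presigned_url defaults to ExpiresIn=3600, while a message
# batch can take up to 24 hours to process, so claude may fetch the URL after
# it has expired. A longer expiry is one option (sketch, value illustrative):
# url = s3.generate_presigned_url('get_object', Params={'Bucket': bucket, 'Key': unquote(key)}, ExpiresIn=86400)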
def ocr(bucket, key):
-url = s3.presigned_get_object(bucket, unquote(key))
+"""Gets OCR content of pdfs"""
+url = s3.generate_presigned_url(
+'get_object',
+Params={
+'Bucket': bucket,
+'Key': unquote(key)
+},
+)
client = anthropic.Anthropic()
message = client.messages.create(
-model='claude-sonnet-4-20250514',
+model='claude-haiku-4-5',
max_tokens=4000,
temperature=1,
messages=[
@@ -136,42 +223,115 @@ def ocr(bucket, key):
)
return message
-def process_events():
-rows = conn.execute("SELECT COUNT(*) FROM embeddings WHERE embedding IS NULL")
+def claude_send_batch(batch):
+"""Send a batch to claude"""
+client = anthropic.Anthropic()
+message_batch = client.messages.batches.create(
+requests=batch
+)
+conn.execute('INSERT INTO batches (platform, batch_id, type) VALUES(%s, %s, %s);', ('claude', message_batch.id, 'ocr'))
+conn.commit()
+logger.info(f'Sent batch_id {message_batch.id} to claude')
+def check_batches():
+"""Check batch status"""
+rows = conn.execute("SELECT COUNT(*) FROM batches WHERE done = false")
row = rows.fetchone()
-logger.info(f'Found {row["count"]} ready to be processed')
-rows = conn.execute("SELECT * FROM embeddings WHERE embedding IS NULL")
+logger.info(f'Found {row["count"]} batch(es) to process.')
+rows = conn.execute("SELECT * FROM batches WHERE done = false")
+client = anthropic.Anthropic()
for row in rows:
+message_batch = client.messages.batches.retrieve(
+row['batch_id'],
+)
+if message_batch.processing_status == 'ended':
+results = client.messages.batches.results(
+row['batch_id'],
+)
+with conn.cursor() as cur:
+for result in results:
+id = int(result.custom_id.split('-')[1])
+try:
+content = result.result.message.content[0].text
+cur.execute('UPDATE embeddings SET content = %s WHERE id = %s;', (content, id))
+except Exception:
+cur.execute('UPDATE embeddings SET content = %s WHERE id = %s;', ('ERROR', id))
+cur.execute('UPDATE batches SET done = true WHERE batch_id = %s;', (row['batch_id'],))
+conn.commit()
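# Note: besides 'succeeded', result.result.type can be 'errored', 'canceled',
# or 'expired'; logging the type before falling back to 'ERROR' makes failures
# traceable. A sketch, inside the same loop as above:
# if result.result.type != 'succeeded':
#     logger.error(f'batch item {result.custom_id} ended as {result.result.type}')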
+def ocr_main():
+"""Checks for any non-OCR'd documents and sends them to claude in batches"""
+## claude 4 sonnet ##
+# tier 1 limit: 8k tokens/min
+# tier 2: enough
+# single pdf page: up to 2k tokens
+# get docs where content is null
+rows = conn.execute("SELECT COUNT(*) FROM (SELECT 1 FROM embeddings WHERE content IS NULL LIMIT 1000) AS batch")  # LIMIT inside a subquery so the count is actually capped at 1000
+row = rows.fetchone()
+logger.info(f'Batching {row["count"]} documents to generate OCR content.')
+rows = conn.execute("SELECT * FROM embeddings WHERE content IS NULL LIMIT 1000")
+# batch docs; set content = WIP
+batch = []
+for row in rows:
+id = row['id']
bucket = row['bucket']
key = row['key']
-logger.info(f'PROCESSING bucket: {bucket}, key: {key}')
-## claude 4 sonnet ##
-# tier 1 limit: 8k tokens/min
-# tier 2: enough
-# single pdf page: up to 2k tokens
-try:
-ai_ocr = ocr(bucket, key)
-text = ai_ocr.content[0].text
-embedding=embed(text, 'search_document')
-conn.execute('UPDATE embeddings SET content = %s, embedding = %s::vector WHERE bucket = %s AND key = %s;', (text, embedding, bucket, key))
-conn.commit()
-except Exception as e:
-logger.error(f"An unexpected error occurred: {e}")
-return True
+logger.info(f'Batching for OCR: {bucket}, key: {key}')
+batch.append(ocr_create_message(id, bucket, key))
+conn.execute('UPDATE embeddings SET content = %s WHERE id = %s;', ('WIP', id))
+conn.commit()
+if len(batch) > 0:
+claude_send_batch(batch)
+def embeddings_main():
+"""Generate embeddings"""
+rows = conn.execute("SELECT COUNT(*) FROM embeddings WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL")
+row = rows.fetchone()
+logger.info(f'Batching {row["count"]} documents to generate embeddings.')
+rows = conn.execute("SELECT id, key, bucket, content FROM embeddings WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL")
+for row in rows:
+logger.info(f'Generating embeddings for id: {row["id"]}, bucket: {row["bucket"]}, key: {row["key"]}')
+embedding = embed(row['content'], 'search_document')
+conn.execute('UPDATE embeddings SET embedding = %s::vector WHERE id = %s;', (embedding, row['id']))
+conn.commit()
+# try:
+# ai_ocr = ocr(bucket, key)
+# text = ai_ocr.content[0].text
+#
+# embedding=embed(text, 'search_document')
+# conn.execute('UPDATE embeddings SET content = %s, embedding = %s::vector WHERE bucket = %s AND key = %s;', (text, embedding, bucket, key))
+# conn.commit()
+# except Exception as e:
+# logger.error(f"An unexpected error occurred: {e}")
+# return True
### embeddings
def embed(text, input_type):
-resp = co.embed(
-texts=[text],
-model='embed-v4.0',
-input_type=input_type,
-embedding_types=['float'],
-)
-return resp.embeddings.float[0]
+max_retries = 3
+for attempt in range(max_retries):
+try:
+resp = co.embed(
+texts=[text],
+model='embed-v4.0',
+input_type=input_type,
+embedding_types=['float'],
+output_dimension=1536,
+)
+return resp.embeddings.float[0]
+except ApiError as e:
+if e.status_code == 502 and attempt < max_retries - 1:
+time.sleep(30 ** attempt)  # exponential backoff: 1 s, then 30 s
+continue
+raise Exception('cohere max retries exceeded')
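# Usage sketch for the query side (an assumption, not part of this commit):
# pgvector's cosine-distance operator <=> pairs with cohere's 'search_query'
# input type; the query text is illustrative.
# query_vec = embed('cave survey techniques', 'search_query')
# hits = conn.execute(
#     'SELECT bucket, key FROM embeddings ORDER BY embedding <=> %s::vector LIMIT 5;',
#     (query_vec,),
# ).fetchall()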
def fix_pages():
i = 766
@@ -183,14 +343,13 @@ def fix_pages():
if __name__ == '__main__':
create_tables()
while True:
-BACKOFF = False
+import_files()
+split_files()
+check_batches()
+ocr_main()
+embeddings_main()
-split_pdfs()
-BACKOFF = process_events()
-if BACKOFF:
-logger.info('backoff detected, sleeping an extra 5 minutes')
-time.sleep(5 * 60)
logger.info('sleeping 5 minutes')
time.sleep(5 * 60)