add mcp, batching
@@ -7,11 +7,11 @@ up () {
     --detach \
     --name cp2-pg \
    --restart unless-stopped \
-    --env-file $HOME/scripts-private/lech/cavepedia-v2/cp2-pg.env \
-    --volume /mammoth/cp2/cp2-pg/data:/var/lib/postgresql/data:rw \
-    --publish 127.0.0.1:4010:5432 \
-    pgvector/pgvector:pg17
+    --env-file $HOME/scripts-private/loser/cavepedia-v2/cp2-pg.env \
+    --volume /texas/cp2/cp2-pg/18/data:/var/lib/postgresql/18/docker:rw \
+    --publish [::1]:9030:5432 \
+    --network pew-net \
+    pgvector/pgvector:pg18
 }
 
 down () {
1 mcp/.python-version Normal file
@@ -0,0 +1 @@
+3.11
5 mcp/README.md Normal file
@@ -0,0 +1,5 @@
+# cavepedia-v2 mcp
+
+# todo
+- signout endpoint
+- auth
BIN mcp/__pycache__/main.cpython-311.pyc Normal file
Binary file not shown.
BIN mcp/__pycache__/server.cpython-311.pyc Normal file
Binary file not shown.
13 mcp/pyproject.toml Normal file
@@ -0,0 +1,13 @@
+[project]
+name = "cavepediav2-mcp"
+version = "0.1.0"
+description = "MCP for cavepediav2"
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = [
+    "asyncio>=4.0.0",
+    "cohere>=5.20.0",
+    "dotenv>=0.9.9",
+    "fastmcp>=2.13.3",
+    "psycopg[binary]>=3.3.2",
+]
48 mcp/search.py Normal file
@@ -0,0 +1,48 @@
+from pgvector.psycopg import register_vector, Bit
+from psycopg.rows import dict_row
+from urllib.parse import unquote
+import anthropic
+import cohere
+import dotenv
+import datetime
+import json
+import minio
+import numpy as np
+import os
+import psycopg
+import time
+
+dotenv.load_dotenv('/home/paul/scripts-private/lech/cavepedia-v2/poller.env')
+
+COHERE_API_KEY = os.getenv('COHERE_API_KEY')
+
+co = cohere.ClientV2(COHERE_API_KEY)
+conn = psycopg.connect(
+    host='127.0.0.1',
+    port=4010,
+    dbname='cavepediav2_db',
+    user='cavepediav2_user',
+    password='cavepediav2_pw',
+    row_factory=dict_row,
+)
+
+def embed(text, input_type):
+    resp = co.embed(
+        texts=[text],
+        model='embed-v4.0',
+        input_type=input_type,
+        embedding_types=['float'],
+    )
+    return resp.embeddings.float[0]
+
+def search():
+    query = 'links trip with not more than 2 people'
+    query_embedding = embed(query, 'search_query')
+
+    rows = conn.execute('SELECT * FROM embeddings ORDER BY embedding <=> %s::vector LIMIT 5', (query_embedding,)).fetchall()
+    for row in rows:
+        print(row['bucket'])
+        print(row['key'])
+
+if __name__ == '__main__':
+    search()
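Note: both search implementations rank rows with "ORDER BY embedding <=> %s::vector", which pgvector evaluates as a full sequential scan unless the column has an ANN index. A minimal sketch of adding an HNSW cosine-distance index, assuming the embeddings table created by the poller (the index name is illustrative, not from this commit):

import psycopg

# Hypothetical one-off migration: vector_cosine_ops matches the <=> operator,
# so ORDER BY embedding <=> ... can use the index instead of a full scan.
conn = psycopg.connect(host='127.0.0.1', port=4010, dbname='cavepediav2_db',
                       user='cavepediav2_user', password='cavepediav2_pw')
conn.execute('CREATE INDEX IF NOT EXISTS embeddings_embedding_idx '
             'ON embeddings USING hnsw (embedding vector_cosine_ops)')
conn.commit()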
79 mcp/server.py Normal file
@@ -0,0 +1,79 @@
+from fastmcp import FastMCP
+from fastmcp.server.auth.providers.auth0 import Auth0Provider
+from psycopg.rows import dict_row
+import cohere
+import dotenv
+import psycopg
+import os
+
+dotenv.load_dotenv('/home/pew/scripts-private/loser/cavepedia-v2/poller.env')
+
+COHERE_API_KEY = os.getenv('COHERE_API_KEY')
+
+co = cohere.ClientV2(COHERE_API_KEY)
+conn = psycopg.connect(
+    host='::1',
+    port=9030,
+    dbname='cavepediav2_db',
+    user='cavepediav2_user',
+    password='cavepediav2_pw',
+    row_factory=dict_row,
+)
+
+
+# The Auth0Provider utilizes Auth0 OIDC configuration
+auth_provider = Auth0Provider(
+    config_url="https://dev-jao4so0av61ny4mr.us.auth0.com/.well-known/openid-configuration",
+    client_id="oONcxma5PNFwYLhrDC4o0PUuAmqDekzM",
+    client_secret="4Z7Wl12ALEtDmNAoERQe7lK2YD9x6jz7H25FiMxRp518dnag-IS2NLLScnmbe4-b",
+    audience="https://dev-jao4so0av61ny4mr.us.auth0.com/me/",
+    base_url="https://mcp.caving.dev",
+    # redirect_path="/auth/callback"  # Default value, customize if needed
+)
+
+mcp = FastMCP("Cavepedia MCP")
+
+def embed(text, input_type):
+    resp = co.embed(
+        texts=[text],
+        model='embed-v4.0',
+        input_type=input_type,
+        embedding_types=['float'],
+    )
+    return resp.embeddings.float[0]
+
+def search(query) -> list[dict]:
+    query_embedding = embed(query, 'search_query')
+
+    rows = conn.execute('SELECT * FROM embeddings WHERE embedding IS NOT NULL ORDER BY embedding <=> %s::vector LIMIT 5', (query_embedding,)).fetchall()
+    docs = []
+    for row in rows:
+        docs.append({ 'key': row['key'], 'content': row['content']})
+    return docs
+
+@mcp.tool
+def get_cave_location(cave: str, state: str, county: str) -> list[dict]:
+    """Lookup cave location as coordinates. Returns up to 5 matches, ordered by most to least relevant."""
+    return search(f'{cave} Location, latitude, Longitude. Located in {state} and {county} county.')
+
+@mcp.tool
+def general_caving_information(query: str) -> list[dict]:
+    """General purpose endpoint for any topic related to caves. Returns up to 5 matches, ordered by most to least relevant."""
+    return search(query)
+
+# Add a protected tool to test authentication
+@mcp.tool
+async def get_token_info() -> dict:
+    """Returns information about the Auth0 token."""
+    from fastmcp.server.dependencies import get_access_token
+
+    token = get_access_token()
+
+    return {
+        "issuer": token.claims.get("iss"),
+        "audience": token.claims.get("aud"),
+        "scope": token.claims.get("scope")
+    }
+
+if __name__ == "__main__":
+    mcp.run(transport='http', host='::1', port=9031)
27 mcp/test/client.py Normal file
@@ -0,0 +1,27 @@
+import asyncio
+from fastmcp import Client
+
+client = Client("http://[::1]:8031/mcp")
+
+async def test_get_cave_location(cave: str, state: str, county: str):
+    async with client:
+        resp = await client.call_tool("get_cave_location", {"cave": cave, "state": state, "county": county})
+        print()
+        print(cave)
+        for item in resp.structured_content['result']:
+            print(item)
+
+async def test_general_caving_information(query: str):
+    async with client:
+        resp = await client.call_tool("general_caving_information", {"query": query})
+        print()
+        print(query)
+        for item in resp.structured_content['result']:
+            print(item)
+
+asyncio.run(test_get_cave_location("Nellies Cave", "VA", "Montgomery"))
+asyncio.run(test_get_cave_location("links cave", "VA", "Giles"))
+#asyncio.run(test_get_cave_location("new river", "VA", "Giles"))
+#asyncio.run(test_get_cave_location("tawneys", "VA", "Giles"))
+#asyncio.run(test_get_cave_location("staty fork", "WV", "Pocahontas"))
+#asyncio.run(test_general_caving_information("broken sunnto"))
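If a tool call fails, it helps to first confirm what the server actually exposes. A small sketch with the same fastmcp Client, assuming fastmcp's list_tools() API (the URL is the one hard-coded above):

import asyncio
from fastmcp import Client

async def show_tools():
    # Print every tool the server registered (expected here:
    # get_cave_location, general_caving_information, get_token_info).
    async with Client("http://[::1]:8031/mcp") as client:
        for tool in await client.list_tools():
            print(tool.name, '-', tool.description)

asyncio.run(show_tools())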
1483 mcp/uv.lock generated Normal file
File diff suppressed because it is too large
@@ -1,53 +0,0 @@
-from pgvector.psycopg import register_vector, Bit
-from psycopg.rows import dict_row
-from urllib.parse import unquote
-from pypdf import PdfReader, PdfWriter
-import anthropic
-import cohere
-import dotenv
-import datetime
-import io
-import json
-import minio
-import numpy as np
-import os
-import psycopg
-import time
-import logging
-from pythonjsonlogger.json import JsonFormatter
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-logHandler = logging.StreamHandler()
-formatter = JsonFormatter("{asctime}{message}", style="{")
-logHandler.setFormatter(formatter)
-logger.addHandler(logHandler)
-
-#####
-
-dotenv.load_dotenv('/home/paul/scripts-private/lech/cavepedia-v2/poller.env')
-
-COHERE_API_KEY = os.getenv('COHERE_API_KEY')
-MINIO_ACCESS_KEY = os.getenv('MINIO_ACCESS_KEY')
-MINIO_SECRET_KEY = os.getenv('MINIO_SECRET_KEY')
-
-s3 = minio.Minio(
-    's3.bigcavemaps.com',
-    access_key=MINIO_ACCESS_KEY,
-    secret_key=MINIO_SECRET_KEY,
-    region='kansascity',
-)
-
-def getobject():
-    bucket = 'cavepedia-v2'
-    key = 'public/var/fyi/VAR-FYI 1982-01.pdf'
-    with s3.get_object(bucket, key) as obj:
-        with open('/tmp/file.pdf', 'wb') as f:
-            while True:
-                chunk = obj.read(1024)
-                if not chunk:
-                    break
-                f.write(chunk)
-
-if __name__ == '__main__':
-    getobject()
319 poller/main.py
@@ -1,8 +1,10 @@
+from cohere.core.api_error import ApiError
 from pgvector.psycopg import register_vector, Bit
 from psycopg.rows import dict_row
 from urllib.parse import unquote
 from pypdf import PdfReader, PdfWriter
 import anthropic
+import boto3
 import cohere
 import dotenv
 import datetime
@@ -25,22 +27,23 @@ logger.addHandler(logHandler)
 
 #####
 
-dotenv.load_dotenv('/home/paul/scripts-private/lech/cavepedia-v2/poller.env')
+dotenv.load_dotenv('/home/pew/scripts-private/loser/cavepedia-v2/poller.env')
 
 COHERE_API_KEY = os.getenv('COHERE_API_KEY')
-MINIO_ACCESS_KEY = os.getenv('MINIO_ACCESS_KEY')
-MINIO_SECRET_KEY = os.getenv('MINIO_SECRET_KEY')
+S3_ACCESS_KEY = os.getenv('S3_ACCESS_KEY')
+S3_SECRET_KEY = os.getenv('S3_SECRET_KEY')
 
-s3 = minio.Minio(
-    's3.bigcavemaps.com',
-    access_key=MINIO_ACCESS_KEY,
-    secret_key=MINIO_SECRET_KEY,
-    region='kansascity',
+s3 = boto3.client(
+    's3',
+    aws_access_key_id=S3_ACCESS_KEY,
+    aws_secret_access_key=S3_SECRET_KEY,
+    endpoint_url='https://s3.bigcavemaps.com',
+    region_name='eu',
 )
-co = cohere.ClientV2(COHERE_API_KEY)
+co = cohere.ClientV2(api_key=COHERE_API_KEY)
 conn = psycopg.connect(
-    host='127.0.0.1',
-    port=4010,
+    host='::1',
+    port=9030,
     dbname='cavepediav2_db',
     user='cavepediav2_user',
     password='cavepediav2_pw',
@@ -51,14 +54,33 @@ conn = psycopg.connect(
 # events table is created by minio upon creation of event destination
 def create_tables():
     commands = (
+        """
+        CREATE TABLE IF NOT EXISTS metadata (
+            id SERIAL PRIMARY KEY,
+            bucket TEXT,
+            key TEXT,
+            split BOOLEAN DEFAULT FALSE,
+            UNIQUE(bucket, key)
+        )
+        """,
+        """
+        CREATE TABLE IF NOT EXISTS batches (
+            id SERIAL PRIMARY KEY,
+            platform TEXT,
+            batch_id TEXT,
+            type TEXT,
+            done BOOLEAN DEFAULT FALSE
+        )
+        """,
         "CREATE EXTENSION IF NOT EXISTS vector",
         """
        CREATE TABLE IF NOT EXISTS embeddings (
+            id SERIAL PRIMARY KEY,
             bucket TEXT,
             key TEXT,
             content TEXT,
             embedding vector(1536),
-            PRIMARY KEY (bucket, key)
+            UNIQUE(bucket, key)
         )
         """)
     for command in commands:
@@ -66,53 +88,118 @@ def create_tables():
         conn.commit()
     register_vector(conn)
 
-## splitting
-def split_pdfs():
-    rows = conn.execute('SELECT * FROM events')
+def import_files():
+    """Scan import bucket for any new files; move them to the files bucket and add to db; delete from import bucket"""
+    BUCKET_IMPORT = 'cavepediav2-import'
+    BUCKET_FILES = 'cavepediav2-files'
+    # get new files; add to db, sync to main bucket; delete from import bucket
+    response = s3.list_objects_v2(Bucket=BUCKET_IMPORT)
+    if 'Contents' in response:
+        for obj in response['Contents']:
+            if obj['Key'].endswith('/'):
+                continue
+            s3.copy_object(
+                CopySource={'Bucket': BUCKET_IMPORT, 'Key': obj['Key']},
+                Bucket=BUCKET_FILES,
+                Key=obj['Key'],
+            )
+            conn.execute('INSERT INTO metadata (bucket, key) VALUES(%s, %s);', (BUCKET_FILES, obj['Key']))
+            conn.commit()
+            s3.delete_object(
+                Bucket=BUCKET_IMPORT,
+                Key=obj['Key'],
+            )
+
+def split_files():
+    """Split PDFs into single pages for easier processing"""
+    BUCKET_PAGES = 'cavepediav2-pages'
+    rows = conn.execute("SELECT COUNT(*) FROM metadata WHERE split = false")
+    row = rows.fetchone()
+    logger.info(f'Found {row["count"]} files to split.')
+    rows = conn.execute('SELECT * FROM metadata WHERE split = false')
 
     for row in rows:
+        bucket = row['bucket']
+        key = row['key']
+
         with conn.cursor() as cur:
-            for record in row['value']['Records']:
-                bucket = record['s3']['bucket']['name']
-                key = record['s3']['object']['key']
-                key = unquote(key)
-                key = key.replace('+',' ')
-                logger.info(f'SPLITTING bucket: {bucket}, key: {key}')
+            logger.info(f'SPLITTING bucket: {bucket}, key: {key}')
 
-                ##### get pdf #####
-                with s3.get_object(bucket, key) as obj:
-                    with open('/tmp/file.pdf', 'wb') as f:
-                        while True:
-                            chunk = obj.read(1024)
-                            if not chunk:
-                                break
-                            f.write(chunk)
+            ##### get pdf #####
+            s3.download_file(bucket, key, '/tmp/file.pdf')
 
-                ##### split #####
-                with open('/tmp/file.pdf', 'rb') as f:
-                    reader = PdfReader(f)
+            ##### split #####
+            with open('/tmp/file.pdf', 'rb') as f:
+                reader = PdfReader(f)
 
-                    for i in range(len(reader.pages)):
-                        writer = PdfWriter()
-                        writer.add_page(reader.pages[i])
+                for i in range(len(reader.pages)):
+                    writer = PdfWriter()
+                    writer.add_page(reader.pages[i])
 
-                        with io.BytesIO() as bs:
-                            writer.write(bs)
-                            bs.seek(0)
-                            s3.put_object(f'{bucket}-pages', f'{key}/page-{i + 1}.pdf', bs, len(bs.getvalue()))
-                            cur.execute('INSERT INTO embeddings (bucket, key) VALUES (%s, %s);', (f'{bucket}-pages', f'{key}/page-{i + 1}.pdf'))
-
-            cur.execute('DELETE FROM events WHERE key = %s', (row['key'],))
+                    with io.BytesIO() as bs:
+                        writer.write(bs)
+                        bs.seek(0)
+                        s3.put_object(
+                            Bucket=BUCKET_PAGES,
+                            Key=f'{key}/page-{i + 1}.pdf',
+                            Body=bs.getvalue()
+                        )
+                        cur.execute('INSERT INTO embeddings (bucket, key) VALUES (%s, %s);', (BUCKET_PAGES, f'{key}/page-{i + 1}.pdf'))
+            cur.execute('UPDATE metadata SET SPLIT = true WHERE id = %s', (row['id'],));
         conn.commit()
 
 ## processing
+def ocr_create_message(id, bucket, key):
+    """Create message to send to claude"""
+    url = s3.generate_presigned_url(
+        'get_object',
+        Params={
+            'Bucket': bucket,
+            'Key': unquote(key)
+        },
+    )
+
+    message = {
+        'custom_id': f'doc-{id}',
+        'params': {
+            'model': 'claude-haiku-4-5',
+            'max_tokens': 4000,
+            'temperature': 1,
+            'messages': [
+                {
+                    'role': 'user',
+                    'content': [
+                        {
+                            'type': 'document',
+                            'source': {
+                                'type': 'url',
+                                'url': url
+                            }
+                        },
+                        {
+                            'type': 'text',
+                            'text': 'Extract all text from this document. Do not include any summary or conclusions of your own.'
+                        }
+                    ]
+                }
+            ]
+        }
+    }
+
+    return message
+
 def ocr(bucket, key):
-    url = s3.presigned_get_object(bucket, unquote(key))
+    """Gets OCR content of pdfs"""
+    url = s3.generate_presigned_url(
+        'get_object',
+        Params={
+            'Bucket': bucket,
+            'Key': unquote(key)
+        },
+    )
 
     client = anthropic.Anthropic()
     message = client.messages.create(
-        model='claude-sonnet-4-20250514',
+        model='claude-haiku-4-5',
         max_tokens=4000,
         temperature=1,
         messages=[
@@ -136,42 +223,115 @@ def ocr(bucket, key):
     )
     return message
 
-def process_events():
-    rows = conn.execute("SELECT COUNT(*) FROM embeddings WHERE embedding IS NULL")
+def claude_send_batch(batch):
+    """Send a batch to claude"""
+    client = anthropic.Anthropic()
+    message_batch = client.messages.batches.create(
+        requests=batch
+    )
+
+    conn.execute('INSERT INTO batches (platform, batch_id, type) VALUES(%s, %s, %s);', ('claude', message_batch.id, 'ocr'))
+    conn.commit()
+
+    logger.info(f'Sent batch_id {message_batch.id} to claude')
+
+def check_batches():
+    """Check batch status"""
+    rows = conn.execute("SELECT COUNT(*) FROM batches WHERE done = false")
     row = rows.fetchone()
-    logger.info(f'Found {row["count"]} ready to be processed')
-
-    rows = conn.execute("SELECT * FROM embeddings WHERE embedding IS NULL")
+    logger.info(f'Found {row["count"]} batch(es) to process.')
+    rows = conn.execute("SELECT * FROM batches WHERE done = false")
 
+    client = anthropic.Anthropic()
     for row in rows:
+        message_batch = client.messages.batches.retrieve(
+            row['batch_id'],
+        )
+        if message_batch.processing_status == 'ended':
+            results = client.messages.batches.results(
+                row['batch_id'],
+            )
+            with conn.cursor() as cur:
+                for result in results:
+                    id = int(result.custom_id.split('-')[1])
+                    try:
+                        content = result.result.message.content[0].text
+                        cur.execute('UPDATE embeddings SET content = %s WHERE id = %s;', (content, id))
+                    except:
+                        cur.execute('UPDATE embeddings SET content = %s WHERE id = %s;', ('ERROR', id))
+                cur.execute('UPDATE batches SET done = true WHERE batch_id = %s;', (row['batch_id'],))
+            conn.commit()
+
+def ocr_main():
+    """Checks for any non-OCR'd documents and sends them to claude in batches"""
+    ## claude 4 sonnet ##
+    # tier 1 limit: 8k tokens/min
+    # tier 2: enough
+    # single pdf page: up to 2k tokens
+
+    # get docs where content is null
+    rows = conn.execute("SELECT COUNT(*) FROM embeddings WHERE content IS NULL LIMIT 1000")
+    row = rows.fetchone()
+    logger.info(f'Batching {row["count"]} documents to generate OCR content.')
+    rows = conn.execute("SELECT * FROM embeddings WHERE content IS NULL LIMIT 1000")
+
+    # batch docs; set content = WIP
+    batch = []
+    for row in rows:
+        id = row['id']
         bucket = row['bucket']
         key = row['key']
-        logger.info(f'PROCESSING bucket: {bucket}, key: {key}')
-
-        ## claude 4 sonnet ##
-        # tier 1 limit: 8k tokens/min
-        # tier 2: enough
-        # single pdf page: up to 2k tokens
-        try:
-            ai_ocr = ocr(bucket, key)
-            text = ai_ocr.content[0].text
-
-            embedding=embed(text, 'search_document')
-            conn.execute('UPDATE embeddings SET content = %s, embedding = %s::vector WHERE bucket = %s AND key = %s;', (text, embedding, bucket, key))
-            conn.commit()
-        except Exception as e:
-            logger.error(f"An unexpected error occurred: {e}")
-            return True
+        logger.info(f'Batching for OCR: {bucket}, key: {key}')
+        batch.append(ocr_create_message(id, bucket, key))
+        conn.execute('UPDATE embeddings SET content = %s WHERE id = %s;', ('WIP', id))
+        conn.commit()
+    if len(batch) > 0:
+        claude_send_batch(batch)
+
+def embeddings_main():
+    """Generate embeddings"""
+    rows = conn.execute("SELECT COUNT(*) FROM embeddings WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL")
+    row = rows.fetchone()
+    logger.info(f'Batching {row["count"]} documents to generate embeddings.')
+    rows = conn.execute("SELECT id, key, bucket, content FROM embeddings WHERE content IS NOT NULL AND content != 'ERROR' AND content != 'WIP' AND embedding IS NULL")
+
+    for row in rows:
+        logger.info(f'Generating embeddings for id: {row["id"]}, bucket: {row["bucket"]}, key: {row["key"]}')
+        embedding = embed(row['content'], 'search_document')
+        conn.execute('UPDATE embeddings SET embedding = %s::vector WHERE id = %s;', (embedding, row['id']))
+        conn.commit()
+
+    # try:
+    #     ai_ocr = ocr(bucket, key)
+    #     text = ai_ocr.content[0].text
+    #
+    #     embedding=embed(text, 'search_document')
+    #     conn.execute('UPDATE embeddings SET content = %s, embedding = %s::vector WHERE bucket = %s AND key = %s;', (text, embedding, bucket, key))
+    #     conn.commit()
+    # except Exception as e:
+    #     logger.error(f"An unexpected error occurred: {e}")
+    #     return True
 
 ### embeddings
 def embed(text, input_type):
-    resp = co.embed(
-        texts=[text],
-        model='embed-v4.0',
-        input_type=input_type,
-        embedding_types=['float'],
-    )
-    return resp.embeddings.float[0]
+    max_retries = 3
+    for attempt in range(max_retries):
+        try:
+            resp = co.embed(
+                texts=[text],
+                model='embed-v4.0',
+                input_type=input_type,
+                embedding_types=['float'],
+                output_dimension=1536,
+            )
+            return resp.embeddings.float[0]
+        except ApiError as e:
+            if e.status_code == 502 and attempt < max_retries - 1:
+                time.sleep(30 ** attempt)  # exponential backoff
+                continue
+            raise Exception('cohere max retries exceeded')
 
 def fix_pages():
     i = 766
@@ -183,14 +343,13 @@ def fix_pages():
 
 if __name__ == '__main__':
     create_tables()
 
     while True:
-        BACKOFF = False
-        split_pdfs()
-        BACKOFF = process_events()
+        import_files()
+        split_files()
+        check_batches()
+        ocr_main()
+        embeddings_main()
 
-        if BACKOFF:
-            logger.info('backoff detected, sleeping an extra 5 minutes')
-            time.sleep(5 * 60)
         logger.info('sleeping 5 minutes')
         time.sleep(5 * 60)
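Note on the lifecycle above: a page's content column moves NULL -> 'WIP' -> extracted text (or 'ERROR' when a batch result cannot be read), and ocr_main() only picks up rows where content IS NULL. A hypothetical maintenance snippet, not part of this commit, to requeue failed pages for the next pass, assuming the schema from create_tables():

import psycopg

# Clearing 'ERROR' puts those pages back into ocr_main()'s
# "content IS NULL" queue so they get re-batched for OCR.
conn = psycopg.connect(host='::1', port=9030, dbname='cavepediav2_db',
                       user='cavepediav2_user', password='cavepediav2_pw')
cur = conn.execute("UPDATE embeddings SET content = NULL WHERE content = 'ERROR'")
print(f'requeued {cur.rowcount} page(s)')
conn.commit()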
poller/pyproject.toml
@@ -6,6 +6,7 @@ readme = "README.md"
 requires-python = ">=3.11"
 dependencies = [
     "anthropic>=0.52.0",
+    "boto3>=1.42.4",
     "cohere>=5.15.0",
     "minio>=7.2.15",
     "mypy>=1.15.0",
74 poller/uv.lock generated
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.11"
 
 [[package]]
@@ -76,6 +76,34 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/5a/e4/bf8034d25edaa495da3c8a3405627d2e35758e44ff6eaa7948092646fdcc/argon2_cffi_bindings-21.2.0-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e415e3f62c8d124ee16018e491a009937f8cf7ebf5eb430ffc5de21b900dad93", size = 53104, upload-time = "2021-12-01T09:09:31.335Z" },
 ]
 
+[[package]]
+name = "boto3"
+version = "1.42.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "botocore" },
+    { name = "jmespath" },
+    { name = "s3transfer" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/f3/31/246916eec4fc5ff7bebf7e75caf47ee4d72b37d4120b6943e3460956e618/boto3-1.42.4.tar.gz", hash = "sha256:65f0d98a3786ec729ba9b5f70448895b2d1d1f27949aa7af5cb4f39da341bbc4", size = 112826, upload-time = "2025-12-05T20:27:14.931Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/00/25/9ae819385aad79f524859f7179cecf8ac019b63ac8f150c51b250967f6db/boto3-1.42.4-py3-none-any.whl", hash = "sha256:0f4089e230d55f981d67376e48cefd41c3d58c7f694480f13288e6ff7b1fefbc", size = 140621, upload-time = "2025-12-05T20:27:12.803Z" },
+]
+
+[[package]]
+name = "botocore"
+version = "1.42.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "jmespath" },
+    { name = "python-dateutil" },
+    { name = "urllib3" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/5c/b7/dec048c124619b2702b5236c5fc9d8e5b0a87013529e9245dc49aaaf31ff/botocore-1.42.4.tar.gz", hash = "sha256:d4816023492b987a804f693c2d76fb751fdc8755d49933106d69e2489c4c0f98", size = 14848605, upload-time = "2025-12-05T20:27:02.919Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/9b/a2/7b50f12a9c5a33cd85a5f23fdf78a0cbc445c0245c16051bb627f328be06/botocore-1.42.4-py3-none-any.whl", hash = "sha256:c3b091fd33809f187824b6434e518b889514ded5164cb379358367c18e8b0d7d", size = 14519938, upload-time = "2025-12-05T20:26:58.881Z" },
+]
+
 [[package]]
 name = "certifi"
 version = "2025.4.26"
@@ -414,6 +442,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/b3/4a/4175a563579e884192ba6e81725fc0448b042024419be8d83aa8a80a3f44/jiter-0.10.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3aa96f2abba33dc77f79b4cf791840230375f9534e5fac927ccceb58c5e604a5", size = 354213, upload-time = "2025-05-18T19:04:41.894Z" },
 ]
 
+[[package]]
+name = "jmespath"
+version = "1.0.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/00/2a/e867e8531cf3e36b41201936b7fa7ba7b5702dbef42922193f05c8976cd6/jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe", size = 25843, upload-time = "2022-06-17T18:00:12.224Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" },
+]
+
 [[package]]
 name = "minio"
 version = "7.2.15"
@@ -545,6 +582,7 @@ version = "0.1.0"
 source = { virtual = "." }
 dependencies = [
     { name = "anthropic" },
+    { name = "boto3" },
     { name = "cohere" },
     { name = "minio" },
     { name = "mypy" },
@@ -559,6 +597,7 @@ dependencies = [
 [package.metadata]
 requires-dist = [
     { name = "anthropic", specifier = ">=0.52.0" },
+    { name = "boto3", specifier = ">=1.42.4" },
     { name = "cohere", specifier = ">=5.15.0" },
     { name = "minio", specifier = ">=7.2.15" },
     { name = "mypy", specifier = ">=1.15.0" },
@@ -756,6 +795,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/a1/4e/931b90b51e3ebc69699be926b3d5bfdabae2d9c84337fd0c9fb98adbf70c/pypdf-5.5.0-py3-none-any.whl", hash = "sha256:2f61f2d32dde00471cd70b8977f98960c64e84dd5ba0d070e953fcb4da0b2a73", size = 303371, upload-time = "2025-05-11T14:00:40.064Z" },
 ]
 
+[[package]]
+name = "python-dateutil"
+version = "2.9.0.post0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "six" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432, upload-time = "2024-03-01T18:36:20.211Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
+]
+
 [[package]]
 name = "python-dotenv"
 version = "1.1.0"
@@ -824,6 +875,27 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928, upload-time = "2024-05-29T15:37:47.027Z" },
 ]
 
+[[package]]
+name = "s3transfer"
+version = "0.16.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "botocore" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/05/04/74127fc843314818edfa81b5540e26dd537353b123a4edc563109d8f17dd/s3transfer-0.16.0.tar.gz", hash = "sha256:8e990f13268025792229cd52fa10cb7163744bf56e719e0b9cb925ab79abf920", size = 153827, upload-time = "2025-12-01T02:30:59.114Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" },
+]
+
+[[package]]
+name = "six"
+version = "1.17.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031, upload-time = "2024-12-04T17:35:28.174Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" },
+]
+
 [[package]]
 name = "sniffio"
 version = "1.3.1"