#!/usr/bin/env python3
"""
SQS to Milvus Candidate Processing Script with FastAPI Scheduler

This script integrates with FastAPI and uses APScheduler to automatically
process AWS SQS messages in batches with multiprocessing for Milvus insertion.
"""

import json
import boto3
import pandas as pd
import time
import sqlite3
import asyncio
from typing import List, Dict, Any, Optional
from pymilvus import connections, Collection, utility, db
import logging
from langchain_openai import OpenAIEmbeddings
import os
from dotenv import load_dotenv
from fastapi import FastAPI
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger
from multiprocessing import Pool, cpu_count
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
import threading
from contextlib import asynccontextmanager

# Load variables from a local .env file (when one exists) into os.environ
# before any os.getenv() lookups in this module run.
load_dotenv()

# Root logging: INFO and above, formatted as "<time> - <LEVEL> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Module-level shared state ---------------------------------------------
# Set while SQSMilvusProcessor.process_all_candidates_multiprocess runs; also
# reported by the "/" and "/logs/status" endpoints.
is_candidate_sync_running = False
# Per-process caches used by get_milvus_connection / get_embeddings.
_milvus_connections = {}
_embeddings_cache = {}
# Serialises access to _milvus_connections inside one process.
_connection_lock = threading.Lock()
# The AsyncIOScheduler created in lifespan(); None until startup.
scheduler = None

def get_milvus_connection(connection_config):
    """Get or create this process's Milvus connection and collection handle.

    Handles are cached in ``_milvus_connections``. The cache key covers the
    process id and every setting that decides which collection is returned
    (host, port, database, collection). Two different configurations used in
    one process therefore never share a handle.

    Args:
        connection_config: mapping with ``milvus_host``, ``milvus_port``,
            ``database_name`` and ``collection_name`` keys (the ``batch_data``
            dicts built by ``SQSMilvusProcessor`` satisfy this).

    Returns:
        dict with ``'alias'`` (the pymilvus connection alias) and
        ``'collection'`` (a ``Collection`` bound to that alias).

    Raises:
        Exception: any pymilvus connection / lookup error, after logging it.
    """
    process_id = os.getpid()
    host = connection_config['milvus_host']
    port = connection_config['milvus_port']
    database_name = connection_config['database_name']
    collection_name = connection_config['collection_name']
    connection_key = f"{process_id}_{host}_{port}_{database_name}_{collection_name}"

    with _connection_lock:
        cached = _milvus_connections.get(connection_key)
        if cached is not None:
            return cached

        # The alias must identify the server and database too. With a bare
        # per-process alias, an existing connection to another host, or one
        # already switched to another database, would be silently reused.
        alias = f"process_{process_id}_{host}_{port}_{database_name}"
        try:
            if not connections.has_connection(alias):
                connections.connect(alias=alias, host=host, port=port)
                logger.info(f"Created new Milvus connection for process {process_id}")

            # Select the database on this alias before binding the collection.
            db.using_database(database_name, using=alias)
            collection = Collection(collection_name, using=alias)
        except Exception as e:
            logger.error(f"Failed to create Milvus connection: {e}")
            raise

        entry = {'alias': alias, 'collection': collection}
        _milvus_connections[connection_key] = entry
        return entry

def get_embeddings():
    """Return this process's OpenAI embeddings client, building it on first use.

    Clients live in ``_embeddings_cache`` keyed by PID. A worker process
    therefore constructs its own client rather than using an entry created
    under its parent's PID.

    Raises:
        Exception: whatever constructing ``OpenAIEmbeddings`` raises (logged,
            then re-raised).
    """
    pid = os.getpid()
    if pid in _embeddings_cache:
        return _embeddings_cache[pid]

    try:
        _embeddings_cache[pid] = OpenAIEmbeddings(
            model="text-embedding-ada-002",
            api_key=os.getenv("OPENAI_API_KEY"),
        )
        logger.info(f"Created embeddings model for process {pid}")
    except Exception as err:
        logger.error(f"Failed to create embeddings model: {err}")
        raise

    return _embeddings_cache[pid]

# Structured text fields copied from each candidate row into its Milvus record.
_CANDIDATE_TEXT_FIELDS = (
    "headline",
    "about",
    "sectors",
    "keywords",
    "skills",
    "languages",
    "projects",
    "certifications",
)


def _parse_candidate_row(candidate_row) -> Optional[tuple]:
    """Extract ``(candidate_id, row)`` from one raw candidate payload.

    The payload is expected to look like ``{'result': [row, ...]}``, where
    ``row`` is a dict carrying ``candidate_id``; only the first row is used.
    Returns None when the payload is malformed or the id is missing or not an
    integer. Such payloads cannot succeed on retry, so callers skip them.
    """
    try:
        row = candidate_row['result'][0]
        raw_id = row.get("candidate_id")
    except (KeyError, IndexError, TypeError, AttributeError) as e:
        logger.error(f"Error processing candidate in batch: {e}")
        return None

    if not raw_id:
        return None

    try:
        # Casting up front guarantees that only a number is ever interpolated
        # into the Milvus filter expressions built from this id.
        return int(raw_id), row
    except (TypeError, ValueError) as e:
        logger.error(f"Error processing candidate in batch: {e}")
        return None


def _build_candidate_record(candidate_id: int, row: Dict[str, Any], embeddings) -> Dict[str, Any]:
    """Build the Milvus record for one candidate: text fields plus an embedding.

    The embedding input is every non-None value of ``row`` except
    ``candidate_id``, stringified, space-joined and run through
    ``clean_resume_text``.
    """
    record: Dict[str, Any] = {"candidate_id": candidate_id}
    for field in _CANDIDATE_TEXT_FIELDS:
        record[field] = str(row.get(field) or "")

    text_parts = [
        str(v) for k, v in row.items()
        if k != "candidate_id" and v is not None
    ]
    record["embedding"] = embeddings.embed_query(clean_resume_text(" ".join(text_parts)))
    return record


def process_candidate_batch(batch_data):
    """
    Process one batch of candidates in a worker process, creating new Milvus
    records or replacing existing ones.

    Args:
        batch_data: dict with ``candidates`` (raw payloads, see
            ``_parse_candidate_row``), the connection settings read by
            ``get_milvus_connection``, and an optional ``batch_id``.

    Returns:
        dict with ``success``, ``processed``, ``created``, ``updated`` and
        ``batch_id``, plus ``error`` whenever ``success`` is False.
        ``success`` is True only if at least one candidate was written AND no
        Milvus or embedding operation failed. A transient failure therefore
        does not make the batch look successful. Malformed payloads are
        skipped and do not count as failures.
    """
    try:
        milvus_conn = get_milvus_connection(batch_data)
        collection = milvus_conn['collection']
        embeddings = get_embeddings()

        # Collapse repeated ids so each candidate is embedded and written once
        # per batch (last payload wins). Milvus insert does not deduplicate
        # primary keys, so a repeated new candidate would otherwise be stored
        # twice.
        latest_rows: Dict[int, Dict[str, Any]] = {}
        for candidate_row in batch_data['candidates']:
            parsed = _parse_candidate_row(candidate_row)
            if parsed is not None:
                candidate_id, row = parsed
                latest_rows[candidate_id] = row

        failed = 0
        candidates_to_create = []
        candidates_to_update = []

        for candidate_id, row in latest_rows.items():
            try:
                existing_records = collection.query(
                    expr=f"candidate_id == {candidate_id}",
                    output_fields=["candidate_id"],
                    limit=1
                )
            except Exception as query_error:
                logger.error(f"Failed to check candidate existence for ID {candidate_id}: {query_error}")
                failed += 1
                continue

            try:
                candidate_data = _build_candidate_record(candidate_id, row, embeddings)
            except Exception as e:
                logger.error(f"Error processing candidate in batch: {e}")
                failed += 1
                continue

            if existing_records:
                candidates_to_update.append(candidate_data)
                logger.info(f"Candidate {candidate_id} exists, marking for update")
            else:
                candidates_to_create.append(candidate_data)
                logger.info(f"Candidate {candidate_id} is new, marking for creation")

        # Updates are delete + insert, one candidate at a time.
        updated_count = 0
        for candidate_data in candidates_to_update:
            candidate_id = candidate_data["candidate_id"]
            try:
                delete_result = collection.delete(expr=f"candidate_id == {candidate_id}")
                logger.info(f"Deleted old candidate record for ID {candidate_id}: {delete_result}")

                insert_result = collection.insert([candidate_data])
                logger.info(f"Inserted updated candidate record for ID {candidate_id}: {insert_result}")
            except Exception as e:
                logger.error(f"Error updating candidate {candidate_id}: {e}")
                failed += 1
                continue

            remove_candidate_entries_from_sqlite(candidate_id)
            updated_count += 1

        # New candidates go in as a single bulk insert.
        created_count = 0
        if candidates_to_create:
            try:
                insert_result = collection.insert(candidates_to_create)
            except Exception as e:
                logger.error(f"Error inserting new candidates: {e}")
                failed += len(candidates_to_create)
            else:
                created_count = len(candidates_to_create)
                logger.info(f"Inserted {created_count} new candidates: {insert_result}")

        # Flush and reload so the written data becomes visible to searches.
        if updated_count > 0 or created_count > 0:
            collection.flush()
            collection.load()
            logger.info("Collection flushed and reloaded")

        total_processed = updated_count + created_count
        result = {
            'success': total_processed > 0 and failed == 0,
            'processed': total_processed,
            'created': created_count,
            'updated': updated_count,
            'batch_id': batch_data.get('batch_id', 'unknown'),
        }
        if failed:
            result['error'] = f"{failed} candidate(s) could not be written to Milvus"
        elif total_processed == 0:
            result['error'] = 'No valid candidates processed in batch'
        return result

    except Exception as e:
        logger.error(f"Error in process_candidate_batch: {e}")
        return {
            'success': False,
            'processed': 0,
            'created': 0,
            'updated': 0,
            'batch_id': batch_data.get('batch_id', 'unknown'),
            'error': str(e)
        }


import re
import unicodedata

def clean_resume_text(text):
    """Normalize free-form resume text before it is embedded.

    Steps, in order:

    1. NFKC Unicode normalization.
    2. Line breaks become spaces; runs of whitespace collapse to one space.
    3. ``|`` and ``%`` (and runs of them) are treated as section separators
       and rendered as `` | ``.
    4. Comma spacing is normalized to ``", "``.
    5. Within each ``|``-separated section, comma-separated items are
       de-duplicated case-insensitively (the first occurrence and its casing
       are kept). Empty items and empty sections are dropped.

    Args:
        text: The raw text; ``None`` or ``""`` yields ``""``.

    Returns:
        The cleaned text with no leading or trailing whitespace.
    """
    if not text:
        return ""

    text = unicodedata.normalize('NFKC', text)

    # Flatten line breaks and collapse whitespace runs.
    text = re.sub(r'[\r\n]+', ' ', text)
    text = re.sub(r'\s{2,}', ' ', text)

    # '|' and '%' are both section separators (note: this also strips the
    # percent sign from values such as "20%").
    text = re.sub(r'[\|%]+', ' | ', text)
    text = re.sub(r'\s*,\s*', ', ', text)

    text = re.sub(r'\s{2,}', ' ', text)
    text = re.sub(r'\s*\|\s*', ' | ', text)

    def dedup_items(section):
        seen = set()
        items = []
        for item in (part.strip() for part in section.split(',')):
            key = item.lower()
            if key and key not in seen:
                seen.add(key)
                items.append(item)
        # Re-join with ", " so the comma spacing normalized above is kept.
        return ", ".join(items)

    sections = (dedup_items(section) for section in text.split('|'))
    # Skipping empty sections avoids "a |  | b" and dangling separators.
    return ' | '.join(section for section in sections if section).strip()


DB_FILE = "recommendations.db"

def remove_candidate_entries_from_sqlite(candidate_id: int, db_file: Optional[str] = None):
    """
    Remove every row in the ``recommendations`` table for ``candidate_id``.

    Used by ``process_candidate_batch`` after a candidate's Milvus record has
    been replaced. This is best effort: it never raises and never creates the
    database file.

    Args:
        candidate_id: Candidate whose rows should be deleted.
        db_file: SQLite database path; defaults to ``DB_FILE``.

    Returns:
        ``{"status": "success", "deleted_rows": n}`` on success, otherwise a
        ``{"status": "skipped", "message": ...}`` dict (missing database file,
        or any SQLite error, which is logged).
    """
    path = db_file if db_file is not None else DB_FILE
    if not os.path.exists(path):
        # Checked explicitly because sqlite3.connect() would create an empty file.
        return {"status": "skipped", "message": "Database file not found"}

    conn = None
    try:
        conn = sqlite3.connect(path)
        # The connection context manager commits on success and rolls back on
        # error; closing is handled in `finally`.
        with conn:
            cursor = conn.execute(
                "DELETE FROM recommendations WHERE candidate_id = ?",
                (candidate_id,)
            )
        return {"status": "success", "deleted_rows": cursor.rowcount}
    except Exception as e:
        # Cleanup failures are reported, not raised, so callers never need to
        # handle them.
        logger.warning(f"Failed to remove candidate {candidate_id} from SQLite: {e}")
        return {"status": "skipped", "message": "found some issue in candidate remove from sqlite db"}
    finally:
        if conn is not None:
            conn.close()


class SQSMilvusProcessor:
    """Drain an SQS queue of candidate payloads and write them to Milvus.

    One run (``process_sqs_queue``) does the following:

    1. Receives the available messages.
    2. Parses each body into one or more candidate payloads.
    3. Processes them in batches of ``batch_size`` across up to
       ``max_workers`` processes via ``process_candidate_batch``.
    4. Deletes the consumed messages when at least one batch reports success.
    """

    # SQS caps both ReceiveMessage and DeleteMessageBatch at 10 messages per call.
    _SQS_MAX_BATCH = 10

    def __init__(self,
                 sqs_queue_url: str,
                 milvus_host: str = "localhost",
                 milvus_port: int = 19530,
                 collection_name: str = "candidates",
                 database_name: str = "default",
                 aws_region: str = "us-east-1",
                 batch_size: int = 100,
                 max_workers: Optional[int] = None,
                 max_messages_per_run: Optional[int] = None):
        """
        Initialize the SQS to Milvus processor.

        Args:
            sqs_queue_url: URL of the queue to drain.
            milvus_host, milvus_port: Milvus server address.
            collection_name, database_name: target Milvus collection.
            aws_region: region for the SQS client.
            batch_size: candidates per worker batch.
            max_workers: worker processes; defaults to ``min(cpu_count(), 4)``.
            max_messages_per_run: optional cap on messages received per run.
                The default ``None`` keeps receiving until SQS returns nothing.
        """
        self.sqs_queue_url = sqs_queue_url
        self.milvus_host = milvus_host
        self.milvus_port = milvus_port
        self.collection_name = collection_name
        self.database_name = database_name
        self.aws_region = aws_region
        self.batch_size = batch_size
        self.max_workers = max_workers or min(cpu_count(), 4)
        self.max_messages_per_run = max_messages_per_run

        # Initialize AWS SQS client
        self.sqs = boto3.client(
            'sqs',
            region_name=self.aws_region,
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY")
        )

        # Main-process connection, used only to validate that the collection exists.
        self.main_collection = None
        self._main_connection_established = False

    def ensure_main_connection(self):
        """Open the parent process's "main" Milvus connection (once).

        Any existing "main" alias is disconnected first. Raises if connecting
        fails or the configured collection is missing from the configured
        database. On failure the alias is torn down, so the next call retries.
        """
        if self._main_connection_established:
            return
        try:
            if connections.has_connection("main"):
                try:
                    connections.disconnect("main")
                    logger.info("Disconnected existing main connection")
                except Exception as e:
                    logger.warning(f"Error disconnecting existing main connection: {e}")

            connections.connect(
                alias="main",
                host=self.milvus_host,
                port=self.milvus_port
            )
            logger.info(f"Created main connection to Milvus at {self.milvus_host}:{self.milvus_port}")

            db.using_database(self.database_name, using="main")
            logger.info(f"Using database: {self.database_name}")

            if not utility.has_collection(self.collection_name, using="main"):
                logger.error(f"Collection '{self.collection_name}' does not exist in database '{self.database_name}'")
                raise Exception(f"Collection '{self.collection_name}' not found")

            self.main_collection = Collection(self.collection_name, using="main")
            self._main_connection_established = True
            logger.info(f"Main process connected to Milvus collection: {self.collection_name}")

        except Exception as e:
            logger.error(f"Failed to establish main connection to Milvus: {e}")
            try:
                if connections.has_connection("main"):
                    connections.disconnect("main")
            except Exception:
                # Best-effort cleanup; the original error is re-raised below.
                pass
            self._main_connection_established = False
            raise

    def get_all_sqs_messages(self) -> List[Dict]:
        """
        Receive the messages currently available on the queue.

        Polls in rounds of up to 10 until a receive returns nothing, a receive
        call fails, or ``max_messages_per_run`` is reached. A message
        redelivered during the drain (e.g. its visibility timeout expired) is
        kept once, with its most recent copy. SQS only reliably deletes a
        message when given the latest receipt handle.
        """
        limit = self.max_messages_per_run
        received: Dict[Any, Dict] = {}

        while limit is None or len(received) < limit:
            wanted = self._SQS_MAX_BATCH if limit is None else min(self._SQS_MAX_BATCH, limit - len(received))
            try:
                response = self.sqs.receive_message(
                    QueueUrl=self.sqs_queue_url,
                    MaxNumberOfMessages=wanted,
                    WaitTimeSeconds=1,  # Short wait to get available messages
                    MessageAttributeNames=['All']
                )
            except Exception as e:
                logger.error(f"Error retrieving SQS messages: {e}")
                break

            messages = response.get('Messages', [])
            if not messages:
                break

            for message in messages:
                # A later copy replaces an earlier one but keeps its position.
                received[message.get('MessageId') or id(message)] = message
            logger.info(f"Retrieved {len(messages)} messages, total: {len(received)}")

            # Small delay to avoid overwhelming SQS
            time.sleep(0.1)

        all_messages = list(received.values())
        logger.info(f"Total messages retrieved from SQS: {len(all_messages)}")
        return all_messages

    def parse_sqs_message(self, message: Dict) -> Optional[Any]:
        """Extract candidate data from an SQS message body.

        Returns ``body['candidate_data']`` if present, the first element of
        ``body['Records']`` if present, a JSON list body as-is, or otherwise
        the whole dict body. Returns None for unparseable bodies and for
        bodies that are neither a JSON object nor a JSON list.
        """
        try:
            body = json.loads(message['Body'])

            if isinstance(body, list):
                return body
            if not isinstance(body, dict):
                raise TypeError(f"unsupported message body type {type(body).__name__}")

            if 'candidate_data' in body:
                return body['candidate_data']
            if 'Records' in body:
                # NOTE(review): only the first record is used; confirm producers
                # never put several candidates in one Records list.
                return body['Records'][0] if body['Records'] else None
            return body

        except (json.JSONDecodeError, KeyError, IndexError, TypeError) as e:
            logger.error(f"Failed to parse SQS message: {e}")
            return None

    def delete_sqs_messages_batch(self, messages: List[Dict]):
        """Delete ``messages`` from SQS in chunks of up to 10.

        Each chunk is sent independently. A failing request is logged and the
        remaining chunks are still attempted. Per-entry failures reported by
        SQS are logged individually.
        """
        step = self._SQS_MAX_BATCH
        for start in range(0, len(messages), step):
            chunk = messages[start:start + step]
            try:
                entries = [
                    {'Id': str(idx), 'ReceiptHandle': message['ReceiptHandle']}
                    for idx, message in enumerate(chunk)
                ]
                response = self.sqs.delete_message_batch(
                    QueueUrl=self.sqs_queue_url,
                    Entries=entries
                )
            except Exception as e:
                logger.error(f"Error deleting SQS messages: {e}")
                continue

            failures = response.get('Failed', [])
            successful = len(response.get('Successful', []))
            logger.info(f"Deleted {successful} messages, {len(failures)} failed")
            for failure in failures:
                logger.warning(
                    f"Failed to delete SQS message entry {failure.get('Id')}: "
                    f"{failure.get('Code')} - {failure.get('Message')}"
                )

    def create_candidate_batches(self, candidates: List[Dict]) -> List[List[Dict]]:
        """Split ``candidates`` into consecutive chunks of ``self.batch_size``."""
        size = self.batch_size
        return [candidates[i:i + size] for i in range(0, len(candidates), size)]

    def process_all_candidates_multiprocess(self, candidates: List[Dict]) -> bool:
        """
        Process ``candidates`` in parallel batches with a ProcessPoolExecutor.

        Returns True if at least one batch succeeded. Returns False when a sync
        is already running, there is nothing to process, or every batch
        failed. The module-level ``is_candidate_sync_running`` flag is held for
        the duration.
        """
        global is_candidate_sync_running

        if is_candidate_sync_running:
            logger.warning("Candidate sync already running, skipping...")
            return False

        if not candidates:
            logger.info("No candidates to process")
            return False

        is_candidate_sync_running = True

        try:
            logger.info(f"Processing {len(candidates)} candidates with multiprocessing")

            candidate_batches = self.create_candidate_batches(candidates)
            logger.info(f"Created {len(candidate_batches)} batches")

            # Each worker receives its candidates plus the settings it needs to
            # open its own Milvus connection.
            batch_data_list = [
                {
                    'candidates': batch,
                    'milvus_host': self.milvus_host,
                    'milvus_port': self.milvus_port,
                    'collection_name': self.collection_name,
                    'database_name': self.database_name,
                    'batch_id': i
                }
                for i, batch in enumerate(candidate_batches)
            ]

            successful_batches = 0
            total_processed = 0

            with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                for result in executor.map(process_candidate_batch, batch_data_list):
                    if result['success']:
                        successful_batches += 1
                        total_processed += result['processed']
                        logger.info(f"Batch {result['batch_id']} processed {result['processed']} candidates")
                    else:
                        logger.error(f"Batch {result['batch_id']} failed: {result.get('error', 'Unknown error')}")

            logger.info(f"Completed processing: {successful_batches}/{len(candidate_batches)} batches successful")
            logger.info(f"Total candidates processed: {total_processed}")

            failed_batches = len(candidate_batches) - successful_batches
            if successful_batches and failed_batches:
                # NOTE(review): messages are deleted as a whole when any batch
                # succeeds, so candidates in failed batches are not retried.
                logger.warning(
                    f"{failed_batches}/{len(candidate_batches)} batches failed; "
                    f"their candidates will not be retried once messages are deleted"
                )

            return successful_batches > 0

        except Exception as e:
            logger.error(f"Error in multiprocess candidate processing: {e}")
            return False
        finally:
            is_candidate_sync_running = False

    async def process_sqs_queue(self):
        """Run one queue-processing pass without blocking the event loop.

        The pass is fully synchronous: SQS receive calls, ``time.sleep``
        pacing, and the blocking ProcessPoolExecutor fan-out. It is run on
        the loop's default thread pool, so FastAPI endpoints stay responsive
        while a sync is in progress.
        """
        try:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._process_sqs_queue_blocking)
        except Exception as e:
            logger.error(f"Error in SQS queue processing: {e}")

    def _process_sqs_queue_blocking(self):
        """Receive, parse, process and delete messages (synchronous body of ``process_sqs_queue``)."""
        logger.info("Starting SQS queue processing...")

        try:
            self.ensure_main_connection()
        except Exception as e:
            logger.error(f"Failed to establish Milvus connection: {e}")
            return

        messages = self.get_all_sqs_messages()
        if not messages:
            logger.info("No messages in queue")
            return

        all_candidates = []
        valid_messages = []
        for message in messages:
            candidate_data = self.parse_sqs_message(message)
            if not candidate_data:
                continue
            if isinstance(candidate_data, dict):
                candidate_data = [candidate_data]
            elif not isinstance(candidate_data, list):
                logger.warning(
                    f"Skipping SQS message with unsupported candidate payload type: "
                    f"{type(candidate_data).__name__}"
                )
                continue
            all_candidates.extend(candidate_data)
            valid_messages.append(message)

        if not all_candidates:
            logger.warning("No valid candidate data found in messages")
            return

        logger.info(f"Extracted {len(all_candidates)} candidates from {len(valid_messages)} messages")

        success = self.process_all_candidates_multiprocess(all_candidates)

        if success:
            self.delete_sqs_messages_batch(valid_messages)
            logger.info("Successfully processed and deleted messages from SQS")
        else:
            logger.error("Failed to process candidates, messages not deleted")

# FastAPI Application with Lifespan Events
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: start the SQS polling scheduler, then clean up on shutdown.

    Startup does three things:

    - builds an ``SQSMilvusProcessor`` from environment variables (each with
      a default);
    - stores it on ``app.state.processor``;
    - schedules ``processor.process_sqs_queue`` every ``SQS_POLL_INTERVAL``
      seconds (default 30), with at most one run at a time.

    Shutdown runs in ``finally``, so it also happens when the app is torn
    down by an exception. It stops the scheduler, disconnects the Milvus
    aliases known to this process, and clears the per-process caches.
    """
    global scheduler
    logger.info("Starting FastAPI application with scheduler...")

    scheduler = AsyncIOScheduler()

    # Every setting can be overridden through the environment variable shown.
    config = {
        "sqs_queue_url": os.getenv(
            "SQS_QUEUE_URL",
            "https://sqs.us-east-1.amazonaws.com/649084263747/hiregroww-prod-queue",
        ),
        "milvus_host": os.getenv("MILVUS_HOST", "localhost"),
        "milvus_port": int(os.getenv("MILVUS_PORT", "19530")),
        "collection_name": os.getenv("CANDIDATE_RECOMMANDATION_COLLECTION_NAME", "candidate_embeddings"),
        "database_name": os.getenv("MILVUS_DB_NAME", "hirenest_recommandation"),
        "aws_region": os.getenv("AWS_DEFAULT_REGION", "us-east-1"),
        "batch_size": int(os.getenv("SQS_BATCH_SIZE", "10")),
        "max_workers": int(os.getenv("SQS_MAX_WORKERS", "4")),
    }

    processor = SQSMilvusProcessor(**config)
    app.state.processor = processor

    seconds = int(os.getenv("SQS_POLL_INTERVAL", 30))
    scheduler.add_job(
        func=processor.process_sqs_queue,
        trigger=IntervalTrigger(seconds=seconds),
        id='sqs_processing_job',
        name='Process SQS Queue',
        replace_existing=True,
        max_instances=1  # Prevent multiple instances running simultaneously
    )
    scheduler.start()

    logger.info("FastAPI application started with SQS processing scheduler")
    logger.info(f"Scheduler will run every {seconds} seconds")

    try:
        yield
    finally:
        logger.info("Shutting down FastAPI application...")

        if scheduler and scheduler.running:
            scheduler.shutdown(wait=True)
            logger.info("Scheduler stopped")

        try:
            if connections.has_connection("main"):
                connections.disconnect("main")
                logger.info("Disconnected main Milvus connection")
        except Exception as e:
            logger.warning(f"Error disconnecting main connection: {e}")

        # Only aliases cached in this (parent) process are visible here.
        for conn_key, conn_info in list(_milvus_connections.items()):
            try:
                alias = conn_info['alias']
                if connections.has_connection(alias):
                    connections.disconnect(alias)
                    logger.info(f"Disconnected Milvus connection: {conn_key}")
            except Exception as e:
                logger.warning(f"Error disconnecting {conn_key}: {e}")

        _milvus_connections.clear()
        _embeddings_cache.clear()
        logger.info("FastAPI application shutdown complete")

# Create FastAPI app with lifespan
# `lifespan` starts the SQS polling scheduler on startup and tears it down on shutdown.
app = FastAPI(
    lifespan=lifespan,
    title="SQS to Milvus Processor",
    version="1.0.0",
    description="Automated SQS to Milvus candidate processing with scheduler",
)

# Health check endpoint (optional - just for monitoring)
@app.get("/")
async def root():
    """Basic liveness check, plus scheduler state and sync state."""
    scheduler_active = scheduler.running if scheduler else False
    sync_label = "running" if is_candidate_sync_running else "idle"
    return {
        "message": "SQS to Milvus Processor is running",
        "scheduler_status": "running" if scheduler_active else "stopped",
        "sync_status": sync_label,
    }

# Monitoring endpoint: current sync, scheduler and connection-cache status
@app.get("/logs/status")
async def get_processing_status():
    """Report the sync flag, scheduler state, job count and cached connection count.

    ``active_connections`` counts entries in this process's
    ``_milvus_connections`` cache.
    """
    if scheduler:
        scheduler_running = scheduler.running
        job_count = len(scheduler.get_jobs())
    else:
        scheduler_running = False
        job_count = 0

    return {
        "sync_running": is_candidate_sync_running,
        "scheduler_running": scheduler_running,
        "active_connections": len(_milvus_connections),
        "scheduled_jobs": job_count,
    }

if __name__ == "__main__":
    import uvicorn

    # uvicorn is given a "module:attribute" import string (rather than the app
    # object) so that reload/workers remain usable. The module name is derived
    # from this file, so renaming the file does not break startup.
    module_name = os.path.splitext(os.path.basename(__file__))[0]

    uvicorn.run(
        f"{module_name}:app",
        host="0.0.0.0",
        port=8024,
        reload=False,  # Set to True for development
        # Keep at 1: each worker process runs the lifespan and would start its
        # own scheduler polling the same queue.
        workers=1,
        log_level="info"
    )
