Multi-Agents in LangGraph for Investment Analysis

Category: LLM Agents
Topic: LLM Agents LLMs LangGraph

Published by Nicole on Nov 21, 2024 • 10 min read.

In this post, you will learn how to perform investment analysis with a multi-agent setup. Along the way, you will use the following tools and learn the following:

  • Exa: embeddings-based web search for finding exactly the content you are looking for. After logging in, get your API key here.
  • SerpApi: Google search results through an API, used here to search for finance and stock information. After logging in, get your API key here.
  • Python REPL: note that the Python REPL can execute arbitrary code on the host machine (e.g., delete files or make network requests), so use it with caution.
  • Tools to read and write a .txt file and to plot historical prices.
  • How to define utilities that help create the graph.
  • How to create a team supervisor and the team of agents.

The full code can be found in this notebook; here I cover only the key aspects.

 

First, create some tools for the agents. The first one uses LangChain's PythonREPL:

 

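from typing import Annotated

from langchain_core.tools import tool
from langchain_experimental.utilities import PythonREPL

# Warning: this executes code locally, which can be unsafe when not sandboxed
repl = PythonREPL()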

@tool
def python_repl(
    code: Annotated[str, "The python code to execute to generate your chart."],
):
    """Use this to execute python code. If you want to see the output of a value,
    you should print it out with `print(...)`. This is visible to the user."""
    try:
        result = repl.run(code)
    except BaseException as e:
        return f"Failed to execute. Error: {repr(e)}"
    result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
    return (
        result_str + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
    )

This tool executes the code the agent generates and returns its output. You will use it to plot the historical stock price. A note of caution, though: PythonREPL executes code locally, which can introduce risks if you don't sandbox it.

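To make the mechanics concrete, here is a minimal, dependency-free sketch of what such a tool does: execute a string of code, capture what it prints, and return a message the agent can read. The function name run_python_snippet is hypothetical and this is not LangChain's PythonREPL implementation; like the real tool, it runs code on the host machine without any sandbox.

```python
import contextlib
import io


def run_python_snippet(code: str) -> str:
    """Hypothetical stand-in for the python_repl tool: run `code` and return
    its captured stdout in a similar success/failure message format."""
    buffer = io.StringIO()
    namespace: dict = {}  # fresh globals for each call in this sketch
    try:
        # Capture everything the snippet prints, as the REPL tool does
        with contextlib.redirect_stdout(buffer):
            exec(code, namespace)
    except Exception as e:
        return f"Failed to execute. Error: {repr(e)}"
    return (
        f"Successfully executed.\nStdout: {buffer.getvalue()}"
        "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
    )


print(run_python_snippet("print(2 + 3)"))
print(run_python_snippet("1 / 0"))
```

Returning the error text instead of raising is the important design choice: the agent sees the failure as an observation and can try again with corrected code.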
Next, create two tools to search the web:
 

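import json

from exa_py import Exa
from langchain_community.utilities import SerpAPIWrapper
from langchain_core.tools import tool

# serp_api_key and exa_api_key are assumed to be defined earlier (see the notebook)


@tool("finance_search")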
def finance_search(query: str) -> str:
    """Search with Google SERP API by a query to Search Google for general
    information related to finance and stocks about a given topic."""
    params = {
        "engine": "google",
        "gl": "us",
        "hl": "en",
    }
    search = SerpAPIWrapper(params=params, serpapi_api_key=serp_api_key)
    return search.run(query)


@tool("exa_search")
def exa_search(question: str) -> str:
    """Tool using Exa's Python SDK to run semantic search and return result highlights."""
    exa = Exa(exa_api_key)

    response = exa.search_and_contents(
        question,
        type="neural",
        use_autoprompt=True,
        num_results=3,
        highlights=True
    )

    results = []
    for each_result in response.results:
        result = {
            "Title": each_result.title,
            "URL": each_result.url,
            "Highlight": "".join(each_result.highlights)
        }
        results.append(result)

    return json.dumps(results)

This setup provides two tools for financial data retrieval:

  • Finance Search: Queries Google using the SERP API to fetch general finance and stock-related information.
  • Exa Search: Uses Exa's SDK for semantic search to deliver context-rich results with highlights, making it ideal for nuanced financial queries.
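
The agent never sees Exa's response objects, only the JSON string built at the end of exa_search. The sketch below isolates that formatting step so you can check it offline. The function name format_exa_results is hypothetical, and the results are SimpleNamespace stand-ins carrying the title, url, and highlights attributes the tool reads. Unlike the tool above, it joins highlights with " ... " so separate snippets don't run together.

```python
import json
from types import SimpleNamespace


def format_exa_results(results) -> str:
    """Turn search results (objects with title, url, highlights) into the
    JSON string handed back to the agent."""
    formatted = []
    for result in results:
        formatted.append({
            "Title": result.title,
            "URL": result.url,
            # Highlights are a list of text snippets; keep them visibly separated
            "Highlight": " ... ".join(result.highlights),
        })
    return json.dumps(formatted)


# Hypothetical stand-ins for the objects in response.results
fake_results = [
    SimpleNamespace(
        title="Example article",
        url="https://example.com/article",
        highlights=["First snippet.", "Second snippet."],
    )
]
print(format_exa_results(fake_results))
```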

Next, create the functions for the technical analysis:

 

import numpy as np


def compute_tech_indicators(data, value):
    # Compute 10 and 30 days Moving Averages
    data['ma10'] = data[value].rolling(window=10).mean()
    data['ma30'] = data[value].rolling(window=30).mean()

    # Create MACD with shorter windows
    data['13ema'] = data[value].ewm(span=13).mean()
    data['6ema'] = data[value].ewm(span=6).mean()
    data['MACD'] = data['6ema'] - data['13ema']

    # MACD_Signal marks crossovers: +1/-1 on days the MACD changes sign, 0 otherwise
    # (this is not the conventional 9-period EMA signal line)
    data['MACD_Signal'] = create_MACD_signal(data['MACD'])

    # Create Exponential Moving Average (shorter window)
    data['ema'] = data[value].ewm(com=0.3).mean()

    # Create Momentum as the 10-day rate of change (price relative to 10 days earlier)
    data['momentum'] = data[value] / data[value].shift(10) - 1

    # RSI over a 7-day window, using simple rolling averages of gains and losses
    delta = data[value].diff()
    gain = delta.where(delta > 0, 0).rolling(window=7).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=7).mean()
    rs = gain / loss
    data['RSI'] = 100 - (100 / (1 + rs))

    return data

def create_MACD_signal(macd_series):
    macd_sign = np.sign(macd_series)
    macd_shifted_sign = macd_sign.shift(1)
    return macd_sign * (macd_sign != macd_shifted_sign)

def generate_MA_crossing(data, value, s_window=20, l_window=50):
    data['short_MA'] = data[value].rolling(window=s_window).mean()
    data['long_MA'] = data[value].rolling(window=l_window).mean()
    data['short_MA-long_MA'] = data['short_MA'] - data['long_MA']
    data['Signal'] = np.where(data['short_MA-long_MA'] > 0, 1, 0)
    data['Signal'] = np.where(data['short_MA-long_MA'] < 0, -1, data['Signal'])
    data['Trading'] = np.sign(data['Signal'] - data['Signal'].shift(1))
    return data

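Because MACD_Signal is a crossover marker rather than a signal line, it helps to see what it produces on a tiny synthetic MACD series. The function body is copied from above; the input values are made up for illustration. Note that the first row is always flagged, because comparing against the NaN created by shift(1) counts as a change.

```python
import numpy as np
import pandas as pd


def create_MACD_signal(macd_series):
    # +1 where the MACD turns positive, -1 where it turns negative, 0 otherwise
    macd_sign = np.sign(macd_series)
    macd_shifted_sign = macd_sign.shift(1)
    return macd_sign * (macd_sign != macd_shifted_sign)


macd = pd.Series([-0.5, -0.2, 0.1, 0.3, -0.1])
print(create_MACD_signal(macd).tolist())
```

Only the days where the MACD crosses zero (rows 2 and 4 here) carry a nonzero value, plus the spurious first row. Some zeros may print as -0.0, which compares equal to 0.0.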
The technical analysis agent uses compute_tech_indicators and generate_MA_crossing through the following tool:

 

import pandas as pd
import yfinance as yf


@tool("yf_tech_analysis")
def yf_tech_analysis(stock_symbol: str, period: str = "1y"):
    """
        Perform a comprehensive technical analysis on the given stock symbol.

        Args:
            stock_symbol (str): The stock symbol to analyze.
            period (str, optional): The time period for analysis. Defaults to "1y".

        Returns:
            tuple: (analysis_results_df, interpretation), a pandas DataFrame with the
            latest indicator values and a dictionary interpreting them.
    """


    data = yf.download(stock_symbol, period=period)
    print("Initial data NaNs:\n", data.isna().sum())

    # Calculate indicators and diagnose NaNs
    data = compute_tech_indicators(data, 'Close')
    data = generate_MA_crossing(data, 'Close')

    # Forward-fill to handle NaNs due to rolling calculations
    data[['ma10', 'ma30', 'short_MA', 'long_MA']] = data[['ma10', 'ma30', 'short_MA', 'long_MA']].ffill()

    # Drop remaining rows with NaNs
    data = data.dropna()
    if data.empty:
        raise ValueError("Insufficient data after calculations. Increase the period or decrease indicator windows.")

    # Create DataFrame for `analysis_results`
    analysis_results_df = pd.DataFrame({
        'Current_Price': [data['Close'].iloc[-1]],
        '10_MA': [data['ma10'].iloc[-1]],
        '30_MA': [data['ma30'].iloc[-1]],
        'Short_MA': [data['short_MA'].iloc[-1]],
        'Long_MA': [data['long_MA'].iloc[-1]],
        '6_EMA': [data['6ema'].iloc[-1]],
        '13_EMA': [data['13ema'].iloc[-1]],
        'EMA': [data['ema'].iloc[-1]],
        'MACD_Value': [data['MACD'].iloc[-1]],
        'MACD_Signal': [data['MACD_Signal'].iloc[-1]],
        'RSI': [data['RSI'].iloc[-1]],
        'Momentum': [data['momentum'].iloc[-1]],
        'MA_Crossing_Signal': [data['Signal'].iloc[-1]],
        'Trading_Action': [data['Trading'].iloc[-1]]
    })

    # Convert the last values to scalars to avoid ambiguity
    # (data[col] can be a single-column DataFrame when yf.download returns multi-level columns)
    def last_scalar(col):
        value = data[col].iloc[-1]
        return value.item() if hasattr(value, 'item') else value

    latest_close = last_scalar('Close')
    latest_ma30 = last_scalar('ma30')
    latest_macd = last_scalar('MACD')
    latest_rsi = last_scalar('RSI')

    # Interpretation
    interpretation = {
        'Trend': 'Bullish' if latest_close > latest_ma30 else 'Bearish',
        'RSI': 'Overbought' if latest_rsi > 70 else ('Oversold' if latest_rsi < 30 else 'Neutral'),
        'MACD': 'Bullish' if latest_macd > 0 else 'Bearish'
    }

    return analysis_results_df, interpretation

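The interpretation step reduces the latest row to three labels using fixed thresholds: price versus the 30-day moving average, RSI above 70 or below 30, and the sign of the MACD. Pulling those rules into a small pure function, here a hypothetical interpret_latest, lets you test them without downloading any data:

```python
def interpret_latest(close: float, ma30: float, macd: float, rsi: float) -> dict:
    """Hypothetical helper mirroring the interpretation rules in yf_tech_analysis."""
    if rsi > 70:
        rsi_label = "Overbought"
    elif rsi < 30:
        rsi_label = "Oversold"
    else:
        rsi_label = "Neutral"
    return {
        # Price above its 30-day moving average is read as an uptrend
        "Trend": "Bullish" if close > ma30 else "Bearish",
        "RSI": rsi_label,
        # Positive MACD means the 6-day EMA sits above the 13-day EMA
        "MACD": "Bullish" if macd > 0 else "Bearish",
    }


print(interpret_latest(close=250.0, ma30=240.0, macd=-1.2, rsi=75.0))
# {'Trend': 'Bullish', 'RSI': 'Overbought', 'MACD': 'Bearish'}
```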
For a deeper analysis, we also run a fundamental analysis. Again, we first create a tool for it:

from datetime import datetime

import yfinance as yf


@tool("yf_fundamental_analysis")
def yf_fundamental_analysis(ticker: str):
    """
        Perform a comprehensive fundamental analysis on the given stock ticker.

        Args:
            ticker (str): The stock ticker symbol to analyze.

        Returns:
            dict: A dictionary with the detailed fundamental analysis results.
    """
    try:
        stock = yf.Ticker(ticker)
        info = stock.info

        # Data processing
        financials = stock.financials.infer_objects(copy=False)
        balance_sheet = stock.balance_sheet.infer_objects(copy=False)
        cash_flow = stock.cashflow.infer_objects(copy=False)

        # Fill missing values
        financials = financials.ffill()
        balance_sheet = balance_sheet.ffill()
        cash_flow = cash_flow.ffill()

        # Key Ratios and Metrics
        ratios = {
            "P/E Ratio": info.get('trailingPE'),
            "Forward P/E": info.get('forwardPE'),
            "P/B Ratio": info.get('priceToBook'),
            "P/S Ratio": info.get('priceToSalesTrailing12Months'),
            "PEG Ratio": info.get('pegRatio'),
            "Debt to Equity": info.get('debtToEquity'),
            "Current Ratio": info.get('currentRatio'),
            "Quick Ratio": info.get('quickRatio'),
            "ROE": info.get('returnOnEquity'),
            "ROA": info.get('returnOnAssets'),
            "ROIC": info.get('returnOnCapital'),
            "Gross Margin": info.get('grossMargins'),
            "Operating Margin": info.get('operatingMargins'),
            "Net Profit Margin": info.get('profitMargins'),
            "Dividend Yield": info.get('dividendYield'),
            "Payout Ratio": info.get('payoutRatio'),
        }

        # Growth Rates
        revenue = financials.loc['Total Revenue']
        net_income = financials.loc['Net Income']
        revenue_growth = revenue.pct_change(periods=-1).iloc[0] if len(revenue) > 1 else None
        net_income_growth = net_income.pct_change(periods=-1).iloc[0] if len(net_income) > 1 else None

        growth_rates = {
            "Revenue Growth (YoY)": revenue_growth,
            "Net Income Growth (YoY)": net_income_growth,
        }

        # Valuation
        market_cap = info.get('marketCap')
        enterprise_value = info.get('enterpriseValue')

        valuation = {
            "Market Cap": market_cap,
            "Enterprise Value": enterprise_value,
            "EV/EBITDA": info.get('enterpriseToEbitda'),
            "EV/Revenue": info.get('enterpriseToRevenue'),
        }

        # Future Estimates
        estimates = {
            "Next Year EPS Estimate": info.get('forwardEps'),
            "Next Year Revenue Estimate": info.get('revenueEstimates', {}).get('avg'),
            "Long-term Growth Rate": info.get('longTermPotentialGrowthRate'),
        }

        # Simple DCF Valuation (very basic)
        free_cash_flow = cash_flow.loc['Free Cash Flow'].iloc[0] if 'Free Cash Flow' in cash_flow.index else None
        wacc = 0.1  # Assumed Weighted Average Cost of Capital
        growth_rate = info.get('longTermPotentialGrowthRate', 0.03)

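        def simple_dcf(fcf, growth_rate, wacc, years=5):
            # The Gordon-growth terminal value is only defined for growth_rate < wacc
            if fcf is None or growth_rate is None or growth_rate >= wacc:
                return None
            # Present value of the cash flows for years 1..years
            dcf_value = sum([fcf * (1 + growth_rate) ** i / (1 + wacc) ** i for i in range(1, years + 1)])
            # Terminal value at the end of the forecast, based on the year (years + 1) cash flow
            terminal_value = fcf * (1 + growth_rate) ** (years + 1) / (wacc - growth_rate)
            dcf_value += terminal_value / (1 + wacc) ** years
            return dcf_value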

        dcf_value = simple_dcf(free_cash_flow, growth_rate, wacc)

        # Prepare the results
        analysis = {
            "Company Name": info.get('longName'),
            "Sector": info.get('sector'),
            "Industry": info.get('industry'),
            "Key Ratios": ratios,
            "Growth Rates": growth_rates,
            "Valuation Metrics": valuation,
            "Future Estimates": estimates,
            "Simple DCF Valuation": dcf_value,
            "Last Updated": datetime.fromtimestamp(info.get('lastFiscalYearEnd', 0)).strftime('%Y-%m-%d'),
            "Data Retrieval Date": datetime.now().strftime('%Y-%m-%d'),
        }

        # Add interpretations. Missing metrics come back as None, so fall back to 0
        # to keep the comparisons from raising a TypeError
        interpretations = {
            "P/E Ratio": "High P/E might indicate overvaluation or high growth expectations" if (ratios.get('P/E Ratio') or 0) > 20 else "Low P/E might indicate undervaluation or low growth expectations",
            # yfinance reports debtToEquity as a percentage (e.g., 150 means 1.5x)
            "Debt to Equity": "High leverage" if (ratios.get('Debt to Equity') or 0) > 200 else "Conservative capital structure",
            "ROE": "Strong returns" if (ratios.get('ROE') or 0) > 0.15 else "Potential profitability issues",
            "Revenue Growth": "Strong growth" if (growth_rates.get('Revenue Growth (YoY)') or 0) > 0.1 else "Slowing growth",
        }

        analysis["Interpretations"] = interpretations

        return analysis

    except Exception as e:
        return f"An error occurred during the analysis: {str(e)}"

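A quick sanity check for a DCF helper like simple_dcf: when a single growth rate g applies to every year, the five-year forecast plus a Gordon-growth terminal value must add up to the plain perpetuity value FCF(1+g)/(WACC - g). The sketch below assumes the standard terminal value built from the year-6 cash flow FCF(1+g)^6 and returns None when g >= WACC, where the formula breaks down. The result is a firm-level value in the same units as the cash flow, not a per-share price.

```python
def simple_dcf(fcf, growth_rate, wacc, years=5):
    """Five-year DCF with a Gordon-growth terminal value (sketch)."""
    if fcf is None or growth_rate is None or growth_rate >= wacc:
        return None
    # Present value of the explicit forecast years 1..years
    explicit = sum(
        fcf * (1 + growth_rate) ** i / (1 + wacc) ** i for i in range(1, years + 1)
    )
    # Terminal value at the end of year `years`, from the year (years + 1) cash flow
    terminal_value = fcf * (1 + growth_rate) ** (years + 1) / (wacc - growth_rate)
    return explicit + terminal_value / (1 + wacc) ** years


value = simple_dcf(100.0, 0.03, 0.10)
gordon = 100.0 * 1.03 / (0.10 - 0.03)
print(round(value, 2), round(gordon, 2))  # prints 1471.43 1471.43
```

If the two numbers disagree, the terminal value is being discounted or grown inconsistently.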
The fundamental analysis agent will use the yf_fundamental_analysis tool later. I won't present more tools here, but as mentioned, you will find all the code in the notebook. Now let's create the agents:

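import functools

# create_agent, agent_node, create_team_supervisor, prelude, llm and the
# write/edit/read document tools are defined in the notebook
exa_search_agent = create_agent(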
    llm,
    [exa_search],
    """As a seasoned investment strategist with 20 years of experience, you weave
    complex financial data into compelling investment narratives,
    your response should clearly articulate the key points you found on {{stock_symbol}} and this {{date}}.""",
)
exa_search_node = functools.partial(agent_node, agent=exa_search_agent, name="ExaSearch")


finance_search_agent = create_agent(
    llm,
    [finance_search],
    """As a seasoned investment strategist with 20 years of experience, you weave
    complex financial data into compelling investment narratives on {{stock_symbol}} and this {{date}}.""",
)
finance_search_node = functools.partial(agent_node, agent=finance_search_agent, name="FinanceSearch")


fundamental_analysis_agent = create_agent(
    llm,
    [yf_fundamental_analysis],
    """With a CFA charter and over 20 years of experience in investing,
    you dissect financial statements and identify key value drivers on {{stock_symbol}} and this {{date}}.""",
)
fundamental_analysis_node = functools.partial(agent_node, agent=fundamental_analysis_agent, name="FinanceAnalysis")

tech_analysis_agent = create_agent(
    llm,
    [yf_tech_analysis],
    """You are a Chartered Market Technician (CMT) with 25 years of experience,
    you have a keen eye for market trends and patterns on {{stock_symbol}} and this {{date}}.""",
)
tech_analysis_node = functools.partial(agent_node, agent=tech_analysis_agent, name="TechnicalAnalysis")


chart_agent = create_agent(
    llm,
    [python_repl],
    """You are a Quant Developer and can write code to plot any charts requested.""",
)
chart_node = functools.partial(agent_node, agent=chart_agent, name="ChartGenerator")

doc_writer_agent = create_agent(
    llm,
    [write_document, edit_document, read_document],
    """You are a Chief Investment Strategist who synthesizes all analyses to create
    a definitive investment report on {{stock_symbol}}.
    \n""",
)

# Injects current directory working state before each call
context_aware_doc_writer_agent = prelude | doc_writer_agent
doc_writing_node = functools.partial(
    agent_node, agent=context_aware_doc_writer_agent, name="DocWriter"
)

investment_analysis_supervisor = create_team_supervisor(
    llm,
    "You are a supervisor tasked with managing a conversation between the"
    " following workers: {team_members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH.",
    ["DocWriter", "ExaSearch", "FinanceSearch", "FinanceAnalysis", "TechnicalAnalysis", "ChartGenerator"],
)

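agent_node and create_agent come from the notebook and are not shown here, but the wrapping pattern is plain Python: functools.partial pre-binds agent and name, so the graph can call each node with just the state. In this sketch, agent_node and fake_agent are hypothetical stand-ins (an agent is modeled as a function from state to text), not the notebook's definitions:

```python
import functools


def agent_node(state: dict, agent, name: str) -> dict:
    """Hypothetical sketch: run an agent on the shared state and return its
    answer as a new message tagged with the worker's name."""
    result = agent(state)
    return {"messages": [{"name": name, "content": result}]}


def fake_agent(state: dict) -> str:
    # Stand-in for a real agent: just reports how many messages it saw
    return f"Saw {len(state['messages'])} message(s)"


exa_search_node = functools.partial(agent_node, agent=fake_agent, name="ExaSearch")
update = exa_search_node({"messages": ["Analyze TSLA"]})
print(update)
```

Tagging each message with the worker's name lets the supervisor tell which worker produced which result.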
In this agent setup, every agent has its own tool, its own task, and a clearly described purpose. These descriptions matter and should be as precise as possible. Also remember that every @tool function needs a docstring: LangChain uses it as the tool description the LLM sees, and raises an error if a function has neither a docstring nor an explicit description.

Now you can create the graph:

 

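import functools
from typing import List

from langchain_core.messages import HumanMessage
from langgraph.graph import END, START, StateGraph

# Create the graph (FinanceTeamState is defined in the notebook)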
financial_graph = StateGraph(FinanceTeamState)
financial_graph.add_node("DocWriter", doc_writing_node)
financial_graph.add_node("ExaSearch", exa_search_node)
financial_graph.add_node("FinanceSearch", finance_search_node)
financial_graph.add_node("FinanceAnalysis", fundamental_analysis_node)
financial_graph.add_node("TechnicalAnalysis", tech_analysis_node)
financial_graph.add_node("ChartGenerator", chart_node)
financial_graph.add_node("supervisor", investment_analysis_supervisor)

# Add the edges
financial_graph.add_edge("DocWriter", "supervisor")
financial_graph.add_edge("ExaSearch", "supervisor")
financial_graph.add_edge("FinanceSearch", "supervisor")
financial_graph.add_edge("FinanceAnalysis", "supervisor")
financial_graph.add_edge("ChartGenerator", "supervisor")
financial_graph.add_edge("TechnicalAnalysis", "supervisor")

# Add the edges where routing applies
financial_graph.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {
        "DocWriter": "DocWriter",
        "FinanceSearch": "FinanceSearch",
        "FinanceAnalysis": "FinanceAnalysis",
        "TechnicalAnalysis": "TechnicalAnalysis",
        "ChartGenerator": "ChartGenerator",
        "ExaSearch": "ExaSearch",
        "FINISH": END,
    },
)

financial_graph.add_edge(START, "supervisor")
chain = financial_graph.compile()


# enter_chain converts a plain string into this graph's input state, which keeps
# its state separate from the state of any top-level graph it is embedded in
def enter_chain(message: str, members: List[str]):
    results = {
        "messages": [HumanMessage(content=message)],
        "team_members": ", ".join(members),
    }
    return results

financial_chain = (
    functools.partial(enter_chain, members=financial_graph.nodes)
    | chain
)

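The wiring above is a hub-and-spoke loop: every worker edge leads back to the supervisor, and the conditional edge reads state["next"] until the supervisor answers FINISH. The sketch below simulates that control flow in plain Python without LangGraph; run_supervised and the scripted router are hypothetical, and its step cap is only a rough analogue of recursion_limit (LangGraph counts every step, supervisor turns included).

```python
from typing import Callable, Dict, List


def run_supervised(
    route: Callable[[List[str]], str],
    workers: Dict[str, Callable[[List[str]], str]],
    request: str,
    recursion_limit: int = 25,
) -> List[str]:
    """Plain-Python simulation of the supervisor loop (not LangGraph itself)."""
    messages = [request]  # shared state: the growing message history
    for _ in range(recursion_limit):
        next_worker = route(messages)  # supervisor decides, like state["next"]
        if next_worker == "FINISH":    # the FINISH -> END branch
            return messages
        # Run the worker; control then returns to the supervisor
        messages.append(f"{next_worker}: {workers[next_worker](messages)}")
    raise RuntimeError("Recursion limit reached before FINISH")


plan = ["TechnicalAnalysis", "FinanceAnalysis", "DocWriter"]


def scripted_supervisor(messages: List[str]) -> str:
    # Hypothetical router: one worker per turn in a fixed order, then FINISH
    step = len(messages) - 1
    return plan[step] if step < len(plan) else "FINISH"


workers = {name: (lambda msgs: "done") for name in plan}
result = run_supervised(scripted_supervisor, workers, "Write a TSLA report")
print(result)
```

A router that never answers FINISH hits the step cap, which is why the run below sets a generous recursion_limit.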
And display the graph:

[Figure: the compiled financial analysis graph]

Now, you can just run the prompt as follows:

 

for s in financial_chain.stream(
    """Write an investment report on {{TESLA}} stock for this {{November, 15th 2024}}. Do a technical analysis and
    a fundamental analysis on {{TESLA}}. Also draw a line graph of the historical price of {{TESLA}}. After that,
    write everything into an investment
    report and save it to disk as a .txt file.""",
    {"recursion_limit": 100},
):
    if "__end__" not in s:
        print(s)
        print("---")

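I want to point out one important detail: the double curly braces, {{}}. In the user prompt above, they are plain text that marks the stock and the date for the agents. In the agents' system prompts they matter more: assuming create_agent turns the system prompt into a LangChain ChatPromptTemplate (see the notebook), single braces mark template variables, such as {team_members} in the supervisor prompt, while double braces are an escape. {{stock_symbol}} therefore reaches the LLM as the literal text {stock_symbol}; with single braces, the template would expect a stock_symbol input that is never passed, and the call would fail. The same applies to every agent definition, for example: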
 

exa_search_agent = create_agent(
    llm,
    [exa_search],
    """As a seasoned investment strategist with 20 years of experience, you weave
    complex financial data into compelling investment narratives,
    your response should clearly articulate the key points you found on {{stock_symbol}} and this {{date}}.""",
)
exa_search_node = functools.partial(agent_node, agent=exa_search_agent, name="ExaSearch")
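
LangChain's default "f-string" prompt templates follow the same brace rules as Python's str.format, so you can check the escaping without LangChain. The prompt text below is a hypothetical example mixing one real template variable with escaped braces:

```python
# Single braces are template variables; double braces render as literal braces
system_prompt = (
    "Key points on {{stock_symbol}} and this {{date}}. "
    "You are one of the following team members: {team_members}."
)

rendered = system_prompt.format(team_members="ExaSearch, FinanceSearch")
print(rendered)

# With single braces, stock_symbol becomes a required variable
try:
    "Key points on {stock_symbol}.".format(team_members="ExaSearch")
except KeyError as missing:
    print("Missing template variable:", missing)
```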