Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions src/agents/intern/generators/code_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Dict, List
import dspy
import json

from src.agents.intern.generators.signatures import (
NewFilesGeneratorSignature,
RelevantFileSelectionSignature,
)
from src.models import Codebase, Ticket


SHOULD_RECORD_INPUT_OUTPUT = True


class CodeGenerator(dspy.Module):
def __init__(self):
super().__init__()

self.relevant_file_selector = dspy.TypedChainOfThought(
RelevantFileSelectionSignature
)
self.new_files_generator = dspy.TypedChainOfThought(NewFilesGeneratorSignature)

def record_input_output(self, inputs: Dict, outputs: Dict):
with open("tests/data/code_generator.json", "w") as f:
json.dump(
{
"inputs": inputs,
"outputs": outputs,
},
f,
)

def forward(self, codebase: Codebase, ticket: Ticket):
relevant_files = self.relevant_file_selector(
files_in_codebase=json.dumps(list(codebase.files.keys())),
ticket=json.dumps(ticket.model_dump()),
)

subset_codebase = {
file: codebase.files[file] for file in relevant_files.relevant_files
}

relevant_codebase = Codebase(files=subset_codebase)

new_files = self.new_files_generator(
relevant_codebase=json.dumps(relevant_codebase.model_dump()),
ticket=json.dumps(ticket.model_dump()),
)

if SHOULD_RECORD_INPUT_OUTPUT:
self.record_input_output(
{
"codebase": codebase.model_dump(),
"ticket": ticket.model_dump(),
},
{
"relevant_files": relevant_files.relevant_files,
"new_files": new_files.new_files,
"explanations": new_files.explanations,
},
)

return {
"relevant_files": relevant_files.relevant_files,
"new_files": new_files.new_files,
"explanations": new_files.explanations,
}
62 changes: 0 additions & 62 deletions src/agents/intern/generators/diff_generator.py

This file was deleted.

22 changes: 22 additions & 0 deletions src/agents/intern/generators/signatures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import dspy
from typing import Dict, List


class RelevantFileSelectionSignature(dspy.Signature):
files_in_codebase = dspy.InputField()
ticket = dspy.InputField()
relevant_files: List[str] = dspy.OutputField(
desc="Give the relevant files for you to observe to complete the ticket. They must be keys of the files_in_codebase dict."
)


# Define the agent
class NewFilesGeneratorSignature(dspy.Signature):
relevant_codebase = dspy.InputField()
ticket = dspy.InputField()
new_files: Dict[str, str] = dspy.OutputField(
desc="Generate the entire files that need to be update or created complete the ticket, with all of their content post update. The key is the path of the file and the value is the content of the file."
)
explanations = dspy.OutputField(
desc="Give explanations for the new files generated. Use Markdown to format the text."
)
10 changes: 5 additions & 5 deletions src/agents/intern/processors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from src.agents.intern.generators.diff_generator import DiffGenerator
from src.language_models import gpt4, mistral
from src.agents.intern.generators.code_generator import CodeGenerator
from src.language_models import gpt4
import dspy

from src.models import Codebase, Ticket
Expand All @@ -22,8 +22,8 @@ def generate_code_change(ticket: Ticket, code_base: Codebase):
# and will return a new code_change
dspy.configure(lm=gpt4)

diff_generator = DiffGenerator()
diff_generator = CodeGenerator()

new_files, explanations = diff_generator(code_base, ticket)
res = diff_generator(code_base, ticket)

return new_files, explanations
return res["new_files"], res["explanations"]
4 changes: 4 additions & 0 deletions src/agents/reviewer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def process_pr(self):
# Get first open PR from GH that hasn't been approved yet
pr = self.pr_backlog.pop(0)

if pr.ticket_id is None:
# This PR is not associated with a Trello ticket
return

# Fetch the Trello ticket that corresponds to this PR
ticket = self.board_helper.get_ticket(ticket_id=pr.ticket_id)

Expand Down
21 changes: 14 additions & 7 deletions src/helpers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,26 @@ def push_changes(
)

def get_entire_codebase(self) -> Codebase:
codebase_dict = {}
contents = self.repo.get_contents("")
if not isinstance(contents, list):
contents = [contents]

codebase_dict = {}
for file in contents:
try:
codebase_dict[file.path] = file.decoded_content.decode("utf-8")
except Exception as e:
pass
def process_contents(contents, path=""):
for item in contents:
if item.type == "dir":
dir_contents = self.repo.get_contents(item.path)
process_contents(dir_contents, path + item.name + "/")
elif item.type == "file":
try:
codebase_dict[path + item.name] = item.decoded_content.decode(
"utf-8"
)
except Exception as e:
pass

process_contents(contents)
codebase = Codebase(files=codebase_dict)

return codebase

def get_file_content(self, file):
Expand Down
30 changes: 30 additions & 0 deletions tests/code_generator.test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Import path above
import sys
import os

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from src.agents.intern.generators.code_generator import CodeGenerator
from src.models import Codebase, Ticket
from src.language_models import gpt4
import dspy
import json
from dotenv import load_dotenv


if __name__ == "__main__":
load_dotenv()

dspy.configure(lm=gpt4)

code_generator = CodeGenerator()

# Import the ticket and codease from a json file in data/code_generator.json
with open("tests/data/code_generator.json", "r") as f:
data = json.load(f)

codebase = Codebase(**data["inputs"]["codebase"])
ticket = Ticket(**data["inputs"]["ticket"])

res = code_generator(codebase, ticket)
print(res)