Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ghcw session 0c2c #8

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Ignore intellij idea files
/.idea/
*iml

github-app.com
# Prevent from pushing the service account key to the repository
service-account.json

Expand All @@ -15,4 +15,5 @@ service-account.json

**/.DS_Store
/azure/backend.tfvars
/azure/.terraform/
/azure/.terraform/
github-app.pem
3 changes: 1 addition & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ build: build-hugo copy-hugo
go build -o $(BUILD_DIR)/$(BINARY_NAME) .

genkit_mode:
export TEST_MODE=true
genkit start
PORT=4000 TEST_MODE=true genkit start

# Run tests
test:
Expand Down
46 changes: 45 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,48 @@ The integration tests will load prompt inputs from files in the `tests/prompts/`
```
6. Create a new Pull Request.


## GitHub App Integration

### Configuring the GitHub App for Scoring Prompts

To configure the GitHub App for scoring prompts, follow these steps:

1. **Create the GitHub App**:
- Go to GitHub Developer Settings.
- Select "New GitHub App".
- Fill out the basic information, such as the app name, description, and callback URL (use any URL for now; it can be a placeholder if you're just testing locally).
- Set Permissions:
- Repository permissions: Set Contents to Read-only.
- Pull Requests: Set Read & Write (for adding status checks or comments).
- Set Subscribe to events: Select Pull request.
- Register the GitHub App.
- Once registered, you'll get a Client ID and Client Secret. You'll also need to generate a Private Key for authenticating requests from the app.

2. **Build the GitHub App Backend**:
- Implement a server to receive GitHub webhook events and process them.
- The server should:
- Receive Webhook Events: Listen for pull request events.
- Check for Specific Files: When a pull request is opened or updated, check if specific files are present.
- Score the Prompt: Use the `score` package to score the prompt.
- Update PR Status: Post a status check on the pull request based on the result.

3. **Deploy the App**:
- Run Locally: Start by running the app locally and using a tool like ngrok to forward GitHub’s webhook events to your local server.
- Set Environment Variables:
- `GITHUB_WEBHOOK_SECRET`: Your GitHub App webhook secret.
- `GITHUB_APP_TOKEN`: An installation access token for the GitHub App.
- Deployment Options: For production, deploy to a cloud provider like AWS, Heroku, or DigitalOcean. Ensure the server is accessible by GitHub for receiving webhook events.

4. **Install the App on Repositories**:
- Once your app is deployed, install it on the repository you want to monitor.
- When a pull request is created or updated, the app will receive the webhook, check for the specified files, score the prompt, and post a status check result.

### Using the GitHub App

The GitHub App allows you to score prompts and update pull requests with the score. It provides a level of assurance that the updated prompt is still secure whenever you update your prompts.

To use the GitHub App:

1. Install the GitHub App on your repository.
2. Create or update a pull request with the prompt you want to score.
3. The GitHub App will receive the webhook event, score the prompt, and update the pull request with the score and pass/fail status.
23 changes: 20 additions & 3 deletions dependencies/genkit_dependencies.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func InitialiseGenkit(ctx context.Context) {
ProjectID: os.Getenv("GCLOUD_PROJECT"),
Location: os.Getenv("GCLOUD_LOCATION"),
}); err != nil {
logger.Log.Fatal("Error initializing Vertex AI", zap.Error(err))
logger.GetLogger().Fatal("Error initializing Vertex AI", zap.Error(err))
}

dotprompt.SetDirectory("prompts")
Expand All @@ -29,15 +29,15 @@ func InitialiseGenkit(ctx context.Context) {
ctx,
googlecloud.Config{ProjectID: os.Getenv("GCLOUD_PROJECT")},
); err != nil {
logger.Log.Fatal("Error initializing Google Cloud", zap.Error(err))
logger.GetLogger().Fatal("Error initializing Google Cloud", zap.Error(err))
}

}

func ProvideModel() ai.Model {
g := vertexai.Model("gemini-1.5-pro")
if g == nil {
logger.Log.Fatal("Model is nil")
logger.GetLogger().Fatal("Model is nil")
}

return g
Expand All @@ -51,3 +51,20 @@ func ProvideReflector() *jsonschema.Reflector {

return r
}

func init() {
// Check if the file 'service-account.json' exists on disk
// Check if the environment variables are set
// Log and exit if the environment variables are not set

if _, err := os.Stat("service-account.json"); os.IsNotExist(err) {
logger.GetLogger().Fatal("service-account.json not found")
}

requiredEnvVars := []string{"GCLOUD_PROJECT", "GCLOUD_LOCATION"}
for _, envVar := range requiredEnvVars {
if os.Getenv(envVar) == "" {
logger.GetLogger().Fatal("Environment variable not set", zap.String("env var", envVar))
}
}
}
4 changes: 2 additions & 2 deletions dependencies/improve_dependencies.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func ProvideImprovePrompt(model ai.Model, reflector *jsonschema.Reflector) *dotp
prompt, err := dotprompt.Open("suggest_improvements")

if err != nil {
logger.Log.Error("Error opening suggest_improvements prompt", zap.Error(err))
logger.GetLogger().Error("Error opening suggest_improvements prompt", zap.Error(err))
return nil
}

Expand All @@ -30,7 +30,7 @@ func ProvideImprovePrompt(model ai.Model, reflector *jsonschema.Reflector) *dotp
)

if err != nil {
logger.Log.Error("Error defining suggest_improvements.prompt", zap.Error(err))
logger.GetLogger().Error("Error defining suggest_improvements.prompt", zap.Error(err))
return nil
}
return scoreLlmPrompt
Expand Down
13 changes: 8 additions & 5 deletions dependencies/scoring_dependencies.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ import (
func ProvideScoringPrompt(model ai.Model, reflector *jsonschema.Reflector) *dotprompt.Prompt {
prompt, err := dotprompt.Open("scoring_prompt")
if err != nil {
logger.Log.Error("Error opening scoring prompt", zap.Error(err))
logger.Log.Fatal("Error opening scoring prompt", zap.Error(err))
logger.GetLogger().Error("Error opening scoring prompt", zap.Error(err))
logger.GetLogger().Fatal("Error opening scoring prompt", zap.Error(err))
}

scoreLlmPrompt, err := dotprompt.Define("scoreLlm.prompt", prompt.TemplateText,
Expand All @@ -28,8 +28,8 @@ func ProvideScoringPrompt(model ai.Model, reflector *jsonschema.Reflector) *dotp
},
)
if err != nil {
logger.Log.Error("Error defining scoreLlm.prompt", zap.Error(err))
logger.Log.Fatal("Error defining scoreLlm.prompt", zap.Error(err))
logger.GetLogger().Error("Error defining scoreLlm.prompt", zap.Error(err))
logger.GetLogger().Fatal("Error defining scoreLlm.prompt", zap.Error(err))
}
return scoreLlmPrompt
}
Expand All @@ -54,7 +54,10 @@ func ProvideScorer(params struct {
})

invokeRequest := func(prompt string) (string, error) {
return scorePromptFlow.Run(context.Background(), prompt)
logger.GetLogger().Info("Invoking score prompt", zap.String("prompt", prompt))
response, err := scorePromptFlow.Run(context.Background(), prompt)
logger.GetLogger().Info("Invoked score prompt", zap.String("response", response), zap.Error(err))
return response, err
}

return score.NewLlmScorer(invokeRequest)
Expand Down
13 changes: 7 additions & 6 deletions endpoints/score.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@
}
if c.Request.Header.Get("Content-Type") == "application/json" {
if err := c.ShouldBindJSON(&requestBody); err != nil {
logger.Log.Error("Invalid JSON payload", zap.Error(err))

logger.GetLogger().Error("Invalid JSON payload", zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid JSON payload"})
return
}
Expand All @@ -109,19 +110,19 @@
cachedResponse, err := promptCache.Get(ctx, prompt+"_score")

if err != nil {
logger.Log.Error("Error getting cached response", zap.Error(err))
c.Redirect(http.StatusPermanentRedirect, "/error")
logger.GetLogger().Error("Failed to get cached response", zap.Error(err))
c.Redirect(http.StatusOK, "/error")
return
}

if cachedResponse != "" {
// Convert cachedResponse to PromptScore
logger.Log.Debug("Using cached response")

Check failure on line 120 in endpoints/score.go

View workflow job for this annotation

GitHub Actions / build-and-test

undefined: logger.Log
var response score.PromptScore
err = json.Unmarshal([]byte(cachedResponse), &response)
if err != nil {
logger.Log.Error("Error unmarshalling cache response", zap.Error(err))
c.Redirect(http.StatusPermanentRedirect, "/error")
logger.GetLogger().Error("Failed to unmarshal cached response", zap.Error(err))
c.Redirect(http.StatusOK, "/error")
return
}

Expand All @@ -137,14 +138,14 @@
return
}

logger.Log.Debug("Scoring prompt", zap.String("prompt", prompt))

Check failure on line 141 in endpoints/score.go

View workflow job for this annotation

GitHub Actions / build-and-test

undefined: logger.Log

response, err := scorer.Score(prompt)

logger.Log.Debug("Prompt scored", zap.Any("response", response))

Check failure on line 145 in endpoints/score.go

View workflow job for this annotation

GitHub Actions / build-and-test

undefined: logger.Log

if err != nil {
logger.Log.Error("Error getting", zap.Error(err))
logger.GetLogger().Error("Failed to score prompt", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to score prompt"})
return
}
Expand All @@ -166,7 +167,7 @@
responseJson, err := json.Marshal(response)

if err != nil {
logger.Log.Error("Error marshalling for cache", zap.Error(err))

Check failure on line 170 in endpoints/score.go

View workflow job for this annotation

GitHub Actions / build-and-test

undefined: logger.Log
}

err = promptCache.Set(context.Background(), prompt, string(responseJson))
Expand All @@ -174,6 +175,6 @@
if err != nil {
// Log the error but ignore - we don't want to fail the request
// if the cache fails
logger.Log.Error("Error caching response", zap.Error(err))

Check failure on line 178 in endpoints/score.go

View workflow job for this annotation

GitHub Actions / build-and-test

undefined: logger.Log
}
}
21 changes: 21 additions & 0 deletions gh/configuration_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package gh

import (
"fmt"

"gopkg.in/yaml.v2"
)

type Config struct {
Prompts []string `yaml:"prompts"`
}

func LoadConfigFromString(content string) (*Config, error) {
var config Config
err := yaml.Unmarshal([]byte(content), &config)
if err != nil {
return nil, fmt.Errorf("failed to decode config content: %w", err)
}

return &config, nil
}
82 changes: 82 additions & 0 deletions gh/file_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package gh

import (
"context"
"fmt"

"PromptDefender-Keep/logger"
"PromptDefender-Keep/score"

"github.com/google/go-github/v66/github"
"go.uber.org/zap"
)

type FileHandler struct {
scorer score.Scorer
client *github.Client
}

func NewFileHandler(scorer score.Scorer, client *github.Client) *FileHandler {
return &FileHandler{
scorer: scorer,
client: client,
}
}

func (fh *FileHandler) ShouldRun(ctx context.Context, owner, repo string, prNumber int, promptFiles []string) (bool, error) {
return true, nil
}

func (fh *FileHandler) RunFilesThroughScoreEndpoint(ctx context.Context, owner, repo, branch string, prNumber int, promptFiles []string) ([]score.PromptScore, error) {
files, _, err := fh.client.PullRequests.ListFiles(ctx, owner, repo, prNumber, nil)
if err != nil {
return nil, fmt.Errorf("error listing files: %w", err)
}

var results []score.PromptScore

for _, file := range files {
isFileInPromptFiles := false

for _, promptFile := range promptFiles {
if file.GetFilename() == promptFile {
isFileInPromptFiles = true
break
}
}

if isFileInPromptFiles == false {
continue
}

logger.GetLogger().Info("Processing file", zap.String("filename", file.GetFilename()), zap.Int("pr_number", prNumber), zap.String("branch", branch))

content, _, _, err := fh.client.Repositories.GetContents(ctx, owner, repo, file.GetFilename(), &github.RepositoryContentGetOptions{
Ref: branch,
})

if err != nil {
return nil, fmt.Errorf("error getting file content: %w", err)
}

if content != nil {
prompt, err := content.GetContent()
if err != nil {
return nil, fmt.Errorf("error getting file content: %w", err)
}

logger.GetLogger().Info("Scoring prompt", zap.String("prompt", prompt))

scoreResult, err := fh.scorer.Score(prompt)
if err != nil {
return nil, fmt.Errorf("error scoring prompt: %w", err)
}

results = append(results, *scoreResult)
} else {
return nil, fmt.Errorf("error getting file content: content is nil")
}
}

return results, nil
}
Loading
Loading