This commit is contained in:
user
2024-03-20 21:23:30 -05:00
parent b0443f6233
commit fb6614f049
7 changed files with 58 additions and 126 deletions

View File

@@ -1,18 +1,24 @@
package controller
import (
"context"
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"github.com/gin-gonic/gin"
openai "github.com/sashabaranov/go-openai"
)
type ChatRequest struct {
Message string `json:"message"`
}
// Response struct to unmarshal the JSON response
type Response struct {
Response string `json:"response"`
}
// GeneralOpenAI godoc
//
// @Summary Gerneral ChatGPT
@@ -33,7 +39,7 @@ func (c *Controller) GeneralOpenAI(ctx *gin.Context) {
req.Message = ctx.Query("message")
}
result, err := c.createChatCompletion(req.Message)
result, err := c.createChatCompletion(req.Message, "openchat")
if err != nil {
err := ctx.AbortWithError(http.StatusInternalServerError, err)
if err != nil {
@@ -63,7 +69,7 @@ func (c *Controller) TravelAgentOpenAI(ctx *gin.Context) {
req.Message = "I want you to act as a travel guide. I will give you my location and you will give me suggestions. " + req.Message
result, err := c.createChatCompletion(req.Message)
result, err := c.createChatCompletion(req.Message, "openchat")
if err != nil {
err := ctx.AbortWithError(http.StatusInternalServerError, err)
if err != nil {
@@ -74,23 +80,43 @@ func (c *Controller) TravelAgentOpenAI(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{"message": result})
}
func (c *Controller) createChatCompletion(message string) (string, error) {
client := c.Cfg.OpenaiClient
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: message,
},
},
},
)
if err != nil {
return "", err
func (c *Controller) createChatCompletion(message string, model string) (string, error) {
// Define the request body
requestBody := map[string]interface{}{
"model": model,
"prompt": message,
"stream": false,
}
return resp.Choices[0].Message.Content, nil
// Convert the request body to JSON
requestBodyBytes, err := json.Marshal(requestBody)
if err != nil {
return "", fmt.Errorf("error encoding request body: %v", err)
}
// Send a POST request to the specified URL with the request body
response, err := http.Post(
"http://"+c.Cfg.LlamaURL+"/api/generate",
"application/json",
bytes.NewBuffer(requestBodyBytes),
)
if err != nil {
return "", fmt.Errorf("error sending POST request: %v", err)
}
defer response.Body.Close()
// Read the response body
responseBody, err := ioutil.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("error reading response body: %v", err)
}
// Unmarshal the JSON response
var resp Response
if err := json.Unmarshal(responseBody, &resp); err != nil {
return "", fmt.Errorf("error decoding response body: %v", err)
}
// Return the response
return resp.Response, nil
}