fixed gpt context order, jwt client start
This commit is contained in:
parent
d8ed1c9560
commit
268471058d
85
auth/jwt_client.go
Normal file
85
auth/jwt_client.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
// auth/jwt_client.go
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JWTClient is a struct representing the JWT authentication client.
|
||||||
|
type JWTClient struct {
|
||||||
|
ServerURL string
|
||||||
|
TokenEndpoint string
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJWTClient creates a new JWTClient instance.
|
||||||
|
func NewJWTClient(serverURL, tokenEndpoint, clientID, clientSecret string) *JWTClient {
|
||||||
|
return &JWTClient{
|
||||||
|
ServerURL: serverURL,
|
||||||
|
TokenEndpoint: tokenEndpoint,
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessToken retrieves a JWT access token from the server.
|
||||||
|
func (c *JWTClient) GetAccessToken() (string, error) {
|
||||||
|
// Construct the payload for token request
|
||||||
|
payload := map[string]string{
|
||||||
|
"client_id": c.ClientID,
|
||||||
|
"client_secret": c.ClientSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert payload to JSON
|
||||||
|
payloadJSON, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make a POST request to the /auth/tokens endpoint
|
||||||
|
resp, err := http.Post(c.ServerURL+c.TokenEndpoint, "application/json", bytes.NewBuffer(payloadJSON))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
response, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("failed to authenticate: %s", response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the obtained token
|
||||||
|
return string(response), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeRequest makes an authenticated HTTP request using the provided access token.
|
||||||
|
func (c *JWTClient) MakeRequest(apiURL, token string) ([]byte, error) {
|
||||||
|
req, err := http.NewRequest("GET", c.ServerURL+"/auth/tokens/make-request", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
response, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
57
bot/chat.go
57
bot/chat.go
@ -17,21 +17,21 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func populateConversationHistory(session *discordgo.Session, channelID string, conversationHistory []openai.ChatCompletionMessage) []openai.ChatCompletionMessage {
|
func populateConversationHistory(session *discordgo.Session, channelID string, conversationHistory []openai.ChatCompletionMessage) []openai.ChatCompletionMessage {
|
||||||
messages, err := session.ChannelMessages(channelID, 50, "", "", "")
|
messages, err := session.ChannelMessages(channelID, 20, "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Error retrieving channel history:", err)
|
log.Error("Error retrieving channel history:", err)
|
||||||
return conversationHistory
|
return conversationHistory
|
||||||
}
|
}
|
||||||
|
|
||||||
totalTokens := 0
|
totalTokens := 0
|
||||||
|
maxHistoryTokens := maxTokens
|
||||||
|
|
||||||
|
// Calculate total tokens without removing any messages
|
||||||
for _, msg := range conversationHistory {
|
for _, msg := range conversationHistory {
|
||||||
totalTokens += len(msg.Content) + len(msg.Role) + 2
|
totalTokens += len(msg.Content) + len(msg.Role) + 2
|
||||||
}
|
}
|
||||||
|
|
||||||
maxHistoryTokens := maxTokens - totalTokens
|
log.Info("Total Tokens Before Trimming:", totalTokens)
|
||||||
if maxHistoryTokens < 0 {
|
|
||||||
maxHistoryTokens = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate from the beginning of conversationHistory (oldest messages)
|
// Iterate from the beginning of conversationHistory (oldest messages)
|
||||||
for i := 0; i < len(conversationHistory); i++ {
|
for i := 0; i < len(conversationHistory); i++ {
|
||||||
@ -40,24 +40,28 @@ func populateConversationHistory(session *discordgo.Session, channelID string, c
|
|||||||
|
|
||||||
if totalTokens-tokens >= maxHistoryTokens {
|
if totalTokens-tokens >= maxHistoryTokens {
|
||||||
// Remove the oldest message
|
// Remove the oldest message
|
||||||
|
log.Info("Removing Oldest Message:", msg.Content)
|
||||||
conversationHistory = conversationHistory[i+1:]
|
conversationHistory = conversationHistory[i+1:]
|
||||||
|
i-- // Adjust index after removal
|
||||||
|
} else {
|
||||||
totalTokens -= tokens
|
totalTokens -= tokens
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Info("Total Tokens After Trimming:", totalTokens)
|
||||||
|
|
||||||
// Add new messages from the channel
|
// Add new messages from the channel
|
||||||
addedTokens := 0
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
for _, message := range messages {
|
message := messages[i]
|
||||||
if len(message.Content) > 0 {
|
if len(message.Content) > 0 {
|
||||||
tokens := len(message.Content) + 2 // Account for role and content tokens
|
tokens := len(message.Content) + 2 // Account for role and content tokens
|
||||||
if totalTokens+tokens <= maxContextTokens && addedTokens+tokens <= maxContextTokens {
|
if totalTokens+tokens <= maxContextTokens {
|
||||||
conversationHistory = append(conversationHistory, openai.ChatCompletionMessage{
|
conversationHistory = append(conversationHistory, openai.ChatCompletionMessage{
|
||||||
Role: openai.ChatMessageRoleUser,
|
Role: openai.ChatMessageRoleUser,
|
||||||
Content: message.Content,
|
Content: message.Content,
|
||||||
})
|
})
|
||||||
totalTokens += tokens
|
totalTokens += tokens
|
||||||
addedTokens += tokens
|
log.Info("Adding New Message:", message.Content)
|
||||||
} else {
|
} else {
|
||||||
if totalTokens+tokens > maxContextTokens {
|
if totalTokens+tokens > maxContextTokens {
|
||||||
log.Warn("Message token count exceeds maxContextTokens:", len(message.Content), len(message.Content)+2)
|
log.Warn("Message token count exceeds maxContextTokens:", len(message.Content), len(message.Content)+2)
|
||||||
@ -69,38 +73,15 @@ func populateConversationHistory(session *discordgo.Session, channelID string, c
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Log the final order of conversation history
|
||||||
|
log.Info("Final Conversation History Order:", conversationHistory)
|
||||||
|
|
||||||
return conversationHistory
|
return conversationHistory
|
||||||
}
|
}
|
||||||
|
|
||||||
func chatGPT(session *discordgo.Session, channelID string, message string, conversationHistory []openai.ChatCompletionMessage) {
|
func chatGPT(session *discordgo.Session, channelID string, message string, conversationHistory []openai.ChatCompletionMessage) {
|
||||||
client := openai.NewClient(OpenAIToken)
|
client := openai.NewClient(OpenAIToken)
|
||||||
|
|
||||||
// Trim conversation history if it exceeds maxContextTokens
|
|
||||||
// totalTokens := 0
|
|
||||||
// trimmedMessages := []openai.ChatCompletionMessage{}
|
|
||||||
|
|
||||||
// for i := len(conversationHistory) - 1; i >= 0; i-- {
|
|
||||||
// msg := conversationHistory[i]
|
|
||||||
// tokens := len(msg.Content) + len(msg.Role) + 2 // Account for role and content tokens
|
|
||||||
|
|
||||||
// if totalTokens+tokens <= maxContextTokens {
|
|
||||||
// trimmedMessages = append([]openai.ChatCompletionMessage{msg}, trimmedMessages...)
|
|
||||||
// totalTokens += tokens
|
|
||||||
// } else {
|
|
||||||
// break
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Update conversationHistory with trimmed conversation history
|
|
||||||
//conversationHistory = trimmedMessages
|
|
||||||
|
|
||||||
// Add user message to conversation history
|
|
||||||
//userMessage := openai.ChatCompletionMessage{
|
|
||||||
// Role: openai.ChatMessageRoleUser,
|
|
||||||
// Content: message,
|
|
||||||
//}
|
|
||||||
//conversationHistory = append(conversationHistory, userMessage)
|
|
||||||
|
|
||||||
// Perform GPT-4 completion
|
// Perform GPT-4 completion
|
||||||
log.Info("Starting completion...", conversationHistory)
|
log.Info("Starting completion...", conversationHistory)
|
||||||
resp, err := client.CreateChatCompletion(
|
resp, err := client.CreateChatCompletion(
|
||||||
@ -115,8 +96,6 @@ func chatGPT(session *discordgo.Session, channelID string, message string, conve
|
|||||||
)
|
)
|
||||||
log.Info("completion done.")
|
log.Info("completion done.")
|
||||||
|
|
||||||
// ...
|
|
||||||
|
|
||||||
// Handle API errors
|
// Handle API errors
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Error connecting to the OpenAI API:", err)
|
log.Error("Error connecting to the OpenAI API:", err)
|
||||||
@ -137,8 +116,6 @@ func chatGPT(session *discordgo.Session, channelID string, message string, conve
|
|||||||
pages = append(pages, gptResponse[i:end])
|
pages = append(pages, gptResponse[i:end])
|
||||||
}
|
}
|
||||||
|
|
||||||
// ...
|
|
||||||
|
|
||||||
// Send the first page
|
// Send the first page
|
||||||
currentPage := 0
|
currentPage := 0
|
||||||
totalPages := len(pages)
|
totalPages := len(pages)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user