Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions api/auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ package api
import (
"context"
"fmt"
"net/url"
"strconv"
"strings"
"time"

comms "bridgerton.audius.co/api/comms"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/gofiber/fiber/v2"
Expand Down Expand Up @@ -172,3 +176,90 @@ func (app *ApiServer) getUserIDFromWallet(ctx context.Context, wallet string) (i
app.resolveWalletCache.Set(key, userId)
return userId, nil
}

/*
* Parses query string for a signed comms GET request and returns the userId
associated with the signing wallet
*/
func (app *ApiServer) userIdForSignedCommsRequest(c *fiber.Ctx) (int, error) {
if c.Method() != "GET" {
return 0, fiber.NewError(fiber.StatusBadRequest, "readSignedGet: bad method: "+c.Method())
}

sigBase64 := c.Get(comms.SigHeader)

// for websocket request, read from query param instead of header
if querySig := c.Query("signature"); sigBase64 == "" && querySig != "" {
sigBase64 = querySig
}

// Check that timestamp is not too old
timestamp, err := strconv.ParseInt(c.Query("timestamp"), 0, 64)
if err != nil {
return 0, fiber.NewError(fiber.StatusBadRequest, "failed to parse timestamp: "+err.Error())
}

tsAge := time.Now().UnixMilli() - timestamp
if tsAge < 0 {
tsAge *= -1
}
if tsAge > comms.SignatureTimeToLiveMs {
return 0, fiber.NewError(fiber.StatusBadRequest, "timestamp not current")
}

// Strip out app_name,api_key,signature to get the parameters that are actually used to generate the signature
uri := c.Request().URI()
path := string(uri.Path())
query := string(uri.QueryString())

queryParams, err := url.ParseQuery(query)
if err != nil {
return 0, fiber.NewError(fiber.StatusBadRequest, "failed to parse query parameters: "+err.Error())
}

queryParams.Del("app_name")
queryParams.Del("api_key")
queryParams.Del("signature")

// Build the final URL string
urlStr := path
if len(queryParams) > 0 {
urlStr += "?" + queryParams.Encode()
}

payload := []byte(urlStr)

wallet, pubkey, err := comms.RecoverSigningWallet(sigBase64, payload)
if err != nil {
return 0, fiber.NewError(fiber.StatusBadRequest, "failed to recoverSigningWallet: "+err.Error())
}
userId, err := app.getUserIDFromWallet(c.Context(), wallet)
if err != nil {
return 0, err
}

app.commsRpcProcessor.SetPubkeyForUser(int32(userId), pubkey)

return userId, nil
}

func (app *ApiServer) readSignedCommsPostRequest(c *fiber.Ctx) ([]byte, string, int, error) {
if c.Method() != "POST" {
return nil, "", 0, fiber.NewError(fiber.StatusBadRequest, "readSignedPost bad method: "+c.Method())
}

payload := c.Body()

sigHex := c.Get(comms.SigHeader)
wallet, pubkey, err := comms.RecoverSigningWallet(sigHex, payload)
if err != nil {
return nil, "", 0, err
}
userId, err := app.getUserIDFromWallet(c.Context(), wallet)
if err != nil {
return nil, "", 0, err
}

app.commsRpcProcessor.SetPubkeyForUser(int32(userId), pubkey)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

return payload, wallet, userId, nil
}
24 changes: 12 additions & 12 deletions api/comms/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func chatCreate(db dbv1.DBTX, ctx context.Context, userId int32, ts time.Time, p
($1, $2, $2)
on conflict (chat_id)
do update set created_at = $2, last_message_at = $2 where chat.created_at > $2
`, params.ChatID, ts)
`, params.ChatID, ts.UTC())
if err != nil {
return err
}
Expand All @@ -63,7 +63,7 @@ func chatCreate(db dbv1.DBTX, ctx context.Context, userId int32, ts time.Time, p
($1, $2, $3, $4, $5)
on conflict (chat_id, user_id)
do update set invited_by_user_id=$2, invite_code=$3, created_at=$5 where chat_member.created_at > $5`,
params.ChatID, userId, invite.InviteCode, invitedUserId, ts)
params.ChatID, userId, invite.InviteCode, invitedUserId, ts.UTC())
if err != nil {
return err
}
Expand Down Expand Up @@ -94,7 +94,7 @@ func chatCreate(db dbv1.DBTX, ctx context.Context, userId int32, ts time.Time, p
}

func chatDelete(db dbv1.DBTX, ctx context.Context, userId int32, chatId string, messageTimestamp time.Time) error {
_, err := db.Exec(ctx, "update chat_member set cleared_history_at = $1, last_active_at = $1, unread_count = 0, is_hidden = true where chat_id = $2 and user_id = $3", messageTimestamp, chatId, userId)
_, err := db.Exec(ctx, "update chat_member set cleared_history_at = $1, last_active_at = $1, unread_count = 0, is_hidden = true where chat_id = $2 and user_id = $3", messageTimestamp.UTC(), chatId, userId)
return err
}

Expand Down Expand Up @@ -168,7 +168,7 @@ func chatSendMessage(db dbv1.DBTX, ctx context.Context, userId int32, chatId str
var err error

_, err = db.Exec(ctx, "insert into chat_message (message_id, chat_id, user_id, created_at, ciphertext) values ($1, $2, $3, $4, $5)",
messageId, chatId, userId, messageTimestamp, ciphertext)
messageId, chatId, userId, messageTimestamp.UTC(), ciphertext)
if err != nil {
return err
}
Expand Down Expand Up @@ -198,9 +198,9 @@ func chatReactMessage(db dbv1.DBTX, ctx context.Context, userId int32, chatId st
($1, $2, $3, $4, $4)
on conflict (user_id, message_id)
do update set reaction = $3, updated_at = $4 where chat_message_reactions.updated_at < $4`,
userId, messageId, *reaction, messageTimestamp)
userId, messageId, *reaction, messageTimestamp.UTC())
} else {
_, err = db.Exec(ctx, "delete from chat_message_reactions where user_id = $1 and message_id = $2 and updated_at < $3", userId, messageId, messageTimestamp)
_, err = db.Exec(ctx, "delete from chat_message_reactions where user_id = $1 and message_id = $2 and updated_at < $3", userId, messageId, messageTimestamp.UTC())
}
if err != nil {
return err
Expand All @@ -213,7 +213,7 @@ func chatReactMessage(db dbv1.DBTX, ctx context.Context, userId int32, chatId st

func chatReadMessages(db dbv1.DBTX, ctx context.Context, userId int32, chatId string, readTimestamp time.Time) error {
_, err := db.Exec(ctx, "update chat_member set unread_count = 0, last_active_at = $1 where chat_id = $2 and user_id = $3",
readTimestamp, chatId, userId)
readTimestamp.UTC(), chatId, userId)
return err
}

Expand Down Expand Up @@ -241,7 +241,7 @@ func updatePermissions(db dbv1.DBTX, ctx context.Context, userId int32, permit C
values ($1, $2, $3, $4)
on conflict (user_id, permits)
do update set allowed = $3 where chat_permissions.updated_at < $4
`, userId, permit, permitAllowed, messageTimestamp)
`, userId, permit, permitAllowed, messageTimestamp.UTC())
return err
}

Expand All @@ -251,7 +251,7 @@ func chatSetPermissions(db dbv1.DBTX, ctx context.Context, userId int32, permits
if allow == nil || permits == ChatPermissionAll || permits == ChatPermissionNone || isInPermitList(ChatPermissionAll, permitList) || isInPermitList(ChatPermissionNone, permitList) {
_, err := db.Exec(ctx, `
delete from chat_permissions where user_id = $1 and updated_at < $2
`, userId, messageTimestamp)
`, userId, messageTimestamp.UTC())
if err != nil {
return err
}
Expand All @@ -263,7 +263,7 @@ func chatSetPermissions(db dbv1.DBTX, ctx context.Context, userId int32, permits
_, err := db.Exec(ctx, `
insert into chat_permissions (user_id, permits, updated_at)
values ($1, $2, $3)
on conflict do nothing`, userId, permits, messageTimestamp)
on conflict do nothing`, userId, permits, messageTimestamp.UTC())
return err
}

Expand All @@ -288,12 +288,12 @@ func chatSetPermissions(db dbv1.DBTX, ctx context.Context, userId int32, permits
}

func chatBlock(db dbv1.DBTX, ctx context.Context, userId int32, blockeeUserId int32, messageTimestamp time.Time) error {
_, err := db.Exec(ctx, "insert into chat_blocked_users (blocker_user_id, blockee_user_id, created_at) values ($1, $2, $3) on conflict do nothing", userId, blockeeUserId, messageTimestamp)
_, err := db.Exec(ctx, "insert into chat_blocked_users (blocker_user_id, blockee_user_id, created_at) values ($1, $2, $3) on conflict do nothing", userId, blockeeUserId, messageTimestamp.UTC())
return err
}

func chatUnblock(db dbv1.DBTX, ctx context.Context, userId int32, unblockedUserId int32, messageTimestamp time.Time) error {
_, err := db.Exec(ctx, "delete from chat_blocked_users where blocker_user_id = $1 and blockee_user_id = $2 and created_at < $3", userId, unblockedUserId, messageTimestamp)
_, err := db.Exec(ctx, "delete from chat_blocked_users where blocker_user_id = $1 and blockee_user_id = $2 and created_at < $3", userId, unblockedUserId, messageTimestamp.UTC())
return err
}

Expand Down
4 changes: 2 additions & 2 deletions api/comms/chat_blast.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func chatBlast(db dbv1.DBTX, ctx context.Context, userId int32, ts time.Time, pa
($1, $2, $3, $4, $5, $6, $7)
on conflict (blast_id)
do nothing
`, params.BlastID, userId, params.Audience, params.AudienceContentType, audienceContentID, params.Message, ts)
`, params.BlastID, userId, params.Audience, params.AudienceContentType, audienceContentID, params.Message, ts.UTC())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -72,7 +72,7 @@ func chatBlast(db dbv1.DBTX, ctx context.Context, userId int32, ts time.Time, pa
SELECT chat_id FROM targ;
`

rows, err := db.Query(ctx, fanOutSql, params.BlastID, ts)
rows, err := db.Query(ctx, fanOutSql, params.BlastID, ts.UTC())
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion api/comms/chat_block_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestChatBlocking(t *testing.T) {
user2Id := seededRand.Int31()

assertBlocked := func(blockerUserId int32, blockeeUserId int32, timestamp time.Time, expected int) {
row := pool.QueryRow(ctx, "select count(*) from chat_blocked_users where blocker_user_id = $1 and blockee_user_id = $2 and created_at = $3", blockerUserId, blockeeUserId, timestamp)
row := pool.QueryRow(ctx, "select count(*) from chat_blocked_users where blocker_user_id = $1 and blockee_user_id = $2 and created_at = $3", blockerUserId, blockeeUserId, timestamp.UTC())
var count int
err := row.Scan(&count)
assert.NoError(t, err)
Expand Down
Loading