diff --git a/api/comms/chat.go b/api/comms/chat.go index a0ffdbc5..3595842e 100644 --- a/api/comms/chat.go +++ b/api/comms/chat.go @@ -71,19 +71,24 @@ func chatCreate(db dbv1.DBTX, ctx context.Context, userId int32, ts time.Time, p } for _, blast := range blasts { + messageId := trashid.BlastMessageID(blast.BlastID, params.ChatID) + _, err = db.Exec(ctx, ` insert into chat_message (message_id, chat_id, user_id, created_at, blast_id) values ($1, $2, $3, $4, $5) on conflict do nothing - `, trashid.BlastMessageID(blast.BlastID, params.ChatID), params.ChatID, blast.FromUserID, blast.CreatedAt, blast.BlastID) + `, messageId, params.ChatID, blast.FromUserID, blast.CreatedAt.UTC(), blast.BlastID) if err != nil { return err } } err = chatUpdateLatestFields(db, ctx, params.ChatID) + if err != nil { + return err + } return err } @@ -317,37 +322,37 @@ func getNewBlasts(tx dbv1.DBTX, ctx context.Context, arg getNewBlastsParams) ([] // see also: subtly different inverse query exists in chat_blast.go // to fan out messages to existing chat var findNewBlasts = ` - with - last_permission_change as ( - select max(t) as t from ( - select updated_at as t from chat_permissions where user_id = $1 - union - select created_at as t from chat_blocked_users where blocker_user_id = $1 - union - select to_timestamp(0) - ) as timestamp_subquery + WITH + last_permission_change AS ( + SELECT max(t) AS t FROM ( + SELECT updated_at AS t FROM chat_permissions WHERE user_id = @user_id + UNION + SELECT created_at AS t FROM chat_blocked_users WHERE blocker_user_id = @user_id + UNION + SELECT to_timestamp(0) + ) AS timestamp_subquery ), - all_new as ( - select * - from chat_blast blast - where - from_user_id in ( + all_new AS ( + SELECT * + FROM chat_blast blast + WHERE + from_user_id IN ( -- follower_audience SELECT followee_user_id AS from_user_id FROM follows WHERE blast.audience = 'follower_audience' AND follows.followee_user_id = blast.from_user_id - AND follows.follower_user_id = $1 + AND follows.follower_user_id = @user_id AND follows.is_delete = false AND follows.created_at < blast.created_at ) - OR from_user_id in ( + OR from_user_id IN ( -- tipper_audience SELECT receiver_user_id FROM user_tips tip WHERE blast.audience = 'tipper_audience' AND receiver_user_id = blast.from_user_id - AND sender_user_id = $1 + AND sender_user_id = @user_id AND tip.created_at < blast.created_at ) OR from_user_id IN ( @@ -358,7 +363,7 @@ func getNewBlasts(tx dbv1.DBTX, ctx context.Context, arg getNewBlastsParams) ([] JOIN tracks og ON remixes.parent_track_id = og.track_id WHERE blast.audience = 'remixer_audience' AND og.owner_id = blast.from_user_id - AND t.owner_id = $1 + AND t.owner_id = @user_id AND ( blast.audience_content_id IS NULL OR ( @@ -366,6 +371,7 @@ func getNewBlasts(tx dbv1.DBTX, ctx context.Context, arg getNewBlastsParams) ([] AND blast.audience_content_id = og.track_id ) ) + AND t.created_at < blast.created_at ) OR from_user_id IN ( -- customer_audience @@ -373,7 +379,7 @@ func getNewBlasts(tx dbv1.DBTX, ctx context.Context, arg getNewBlastsParams) ([] FROM usdc_purchases p WHERE blast.audience = 'customer_audience' AND p.seller_user_id = blast.from_user_id - AND p.buyer_user_id = $1 + AND p.buyer_user_id = @user_id AND ( audience_content_id IS NULL OR ( @@ -381,6 +387,7 @@ func getNewBlasts(tx dbv1.DBTX, ctx context.Context, arg getNewBlastsParams) ([] AND blast.audience_content_id = p.content_id ) ) + AND p.created_at < blast.created_at ) OR from_user_id IN ( -- coin_holder_audience via sol_user_balances @@ -389,19 +396,21 @@ func getNewBlasts(tx dbv1.DBTX, ctx context.Context, arg getNewBlastsParams) ([] JOIN sol_user_balances sub ON sub.mint = ac.mint WHERE blast.audience = 'coin_holder_audience' AND ac.user_id = blast.from_user_id - AND sub.user_id = $1 + AND sub.user_id = @user_id AND sub.balance > 0 -- TODO: PE-6663 This isn't entirely correct yet, need to check "time of most recent membership" AND sub.created_at < blast.created_at ) ) - select * from all_new - where created_at > (select t from last_permission_change) - and chat_allowed(from_user_id, $1) - order by created_at - ` - - rows, err := tx.Query(ctx, findNewBlasts, arg.UserID) + SELECT * FROM all_new + WHERE created_at > (select t from last_permission_change) + AND chat_allowed(from_user_id, @user_id) + ORDER BY created_at + ;` + + rows, err := tx.Query(ctx, findNewBlasts, pgx.NamedArgs{ + "user_id": arg.UserID, + }) if err != nil { return nil, err } diff --git a/api/comms/chat_blast.go b/api/comms/chat_blast.go index 96000b06..3aea22b4 100644 --- a/api/comms/chat_blast.go +++ b/api/comms/chat_blast.go @@ -11,8 +11,7 @@ import ( // Result struct to hold chat_id and to_user_id type ChatBlastResult struct { - ChatID string `db:"chat_id"` - ToUserID int32 `db:"to_user_id"` + ChatID string `db:"chat_id"` } type OutgoingChatMessage struct { @@ -80,11 +79,7 @@ func chatBlast(db dbv1.DBTX, ctx context.Context, userId int32, ts time.Time, pa defer rows.Close() // Scan the results into the results slice - results, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (ChatBlastResult, error) { - var result ChatBlastResult - err := row.Scan(&result.ChatID, &result.ToUserID) - return result, err - }) + results, err = pgx.CollectRows(rows, pgx.RowToStructByName[ChatBlastResult]) if err != nil { return nil, err } diff --git a/api/comms/chat_blast_test.go b/api/comms/chat_blast_test.go new file mode 100644 index 00000000..a2e113fa --- /dev/null +++ b/api/comms/chat_blast_test.go @@ -0,0 +1,1009 @@ +package comms + +import ( + "context" + "fmt" + "testing" + "time" + + "bridgerton.audius.co/database" + "bridgerton.audius.co/trashid" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/assert" +) + +/* + Note: There is some overlap between these tests and those in comms_blasts_test.go + These tests are meant to exercise the write path. +*/ + +func mustGetMessagesAndReactions(t *testing.T, pool *pgxpool.Pool, ctx context.Context, userID int32, chatID string) []chatMessageAndReactionsRow { + messages, err := getChatMessagesAndReactions(pool, ctx, chatMessagesAndReactionsParams{ + UserID: userID, + ChatID: chatID, + Limit: 10, + Before: time.Now().Add(time.Hour * 2).UTC(), + After: time.Now().Add(time.Hour * -2).UTC(), + }) + assert.NoError(t, err) + return messages +} + +func TestChatBlastFollowers(t *testing.T) { + t0 := time.Now().Add(time.Second * -100).UTC() + t1 := time.Now().Add(time.Second * -90).UTC() + t2 := time.Now().Add(time.Second * -80).UTC() + t3 := time.Now().Add(time.Second * -70).UTC() + t4 := time.Now().Add(time.Second * -60).UTC() + t5 := time.Now().Add(time.Second * -50).UTC() + t6 := time.Now().Add(time.Second * -40).UTC() + + pool := database.CreateTestDatabase(t, "test_comms") + defer pool.Close() + database.Seed(pool, database.FixtureMap{ + "users": { + {"user_id": 68, "wallet": "wallet68", "handle": "user68"}, + {"user_id": 1, "wallet": "wallet1", "handle": "user1"}, + {"user_id": 100, "wallet": "wallet100", "handle": "user100"}, + {"user_id": 101, "wallet": "wallet101", "handle": "user101"}, + {"user_id": 102, "wallet": "wallet102", "handle": "user102"}, + {"user_id": 103, "wallet": "wallet103", "handle": "user103"}, + {"user_id": 104, "wallet": "wallet104", "handle": "user104"}, + }, + "follows": {{ + "follower_user_id": 68, + "followee_user_id": 1, + "created_at": t0, + }, { + "follower_user_id": 1, + "followee_user_id": 68, + "created_at": t0, + }, + { + "follower_user_id": 100, + "followee_user_id": 1, + "created_at": t0, + }, + { + "follower_user_id": 101, + "followee_user_id": 1, + "created_at": t0, + }, + { + "follower_user_id": 102, + "followee_user_id": 1, + "created_at": t0, + }, + { + "follower_user_id": 103, + "followee_user_id": 1, + "created_at": t0, + }, + { + "follower_user_id": 104, + "followee_user_id": 1, + "created_at": t0, + }, + }, + }) + validator := CreateTestValidator(t, pool, DefaultRateLimitConfig, DefaultTestValidatorConfig) + + ctx := context.Background() + + var count = 0 + var messages []chatMessageAndReactionsRow + + // Blaster (user 1) closes inbox + // But recipients should still be able to upgrade. + { + err := chatSetPermissions(pool, ctx, 1, ChatPermissionNone, nil, nil, t0) + assert.NoError(t, err) + } + + // Other user (104) closes inbox + { + err := chatSetPermissions(pool, ctx, 104, ChatPermissionNone, nil, nil, t0) + assert.NoError(t, err) + } + + // ----------------- some threads already exist ------------- + // user 100 starts a thread with 1 before first blast + chatId_100_1 := trashid.ChatID(100, 1) + chatId_1_103 := trashid.ChatID(1, 103) + { + err := chatCreate(pool, ctx, 100, t1, ChatCreateRPCParams{ + ChatID: chatId_100_1, + Invites: []PurpleInvite{ + {UserID: trashid.MustEncodeHashID(100), InviteCode: "x"}, + {UserID: trashid.MustEncodeHashID(1), InviteCode: "x"}, + }, + }) + assert.NoError(t, err) + + // send a message in chat + err = chatSendMessage(pool, ctx, 100, chatId_100_1, "pre1", t1, "100 here sending 1 a message") + assert.NoError(t, err) + + messages = mustGetMessagesAndReactions(t, pool, ctx, 100, chatId_100_1) + assert.Len(t, messages, 1) + assert.False(t, messages[0].IsPlaintext) + + messages = mustGetMessagesAndReactions(t, pool, ctx, 1, chatId_100_1) + assert.Len(t, messages, 1) + + ch, err := getUserChat(pool, ctx, chatMembershipParams{ + UserID: 1, + ChatID: chatId_100_1, + }) + assert.NoError(t, err) + assert.False(t, ch.LastMessageIsPlaintext) + + // user 1 now has 1 (real) chats + chats, err := getUserChats(pool, ctx, userChatsParams{ + UserID: 1, + Limit: 10, + Before: time.Now().Add(time.Hour * 2).UTC(), + After: time.Now().Add(time.Hour * -2).UTC(), + }) + assert.NoError(t, err) + assert.Len(t, chats, 1) + } + + // user 1 starts empty thread with 103 before first blast + { + err := chatCreate(pool, ctx, 1, t1, ChatCreateRPCParams{ + ChatID: chatId_1_103, + Invites: []PurpleInvite{ + {UserID: trashid.MustEncodeHashID(1), InviteCode: "x"}, + {UserID: trashid.MustEncodeHashID(103), InviteCode: "x"}, + }, + }) + assert.NoError(t, err) + + // user 1 still has 1 (real) chats + // because this is empty + chats, err := getUserChats(pool, ctx, userChatsParams{ + UserID: 1, + Limit: 10, + Before: time.Now().Add(time.Hour * 2).UTC(), + After: time.Now().Add(time.Hour * -2).UTC(), + }) + assert.NoError(t, err) + assert.Len(t, chats, 1) + } + + // ----------------- a first blast ------------------------ + chatId_101_1 := trashid.ChatID(101, 1) + + outgoingMessages, err := chatBlast(pool, ctx, 1, t2, ChatBlastRPCParams{ + BlastID: "b1", + Audience: FollowerAudience, + Message: "what up fam", + }) + assert.NoError(t, err) + + // Test that outgoing messages contain the audience field + for _, outgoingMsg := range outgoingMessages { + assert.NotNil(t, outgoingMsg.ChatMessageRPC.Params.Audience, "Audience should be set in outgoing message") + assert.Equal(t, FollowerAudience, *outgoingMsg.ChatMessageRPC.Params.Audience, "Audience should match the blast audience") + } + + pool.QueryRow(ctx, `select count(*) from chat_blast`).Scan(&count) + assert.Equal(t, 1, count) + + pool.QueryRow(ctx, `select count(*) from chat where chat_id = $1`, chatId_101_1).Scan(&count) + assert.Equal(t, 0, count) + + pool.QueryRow(ctx, `select count(*) from chat_member where chat_id = $1`, chatId_101_1).Scan(&count) + assert.Equal(t, 0, count) + + pool.QueryRow(ctx, `select count(*) from chat_message where chat_id = $1`, chatId_101_1).Scan(&count) + assert.Equal(t, 0, count) + + // user 1 gets chat list... + { + // user 1 now has a (pre-existing) chat and a blast + chats, err := getUserChats(pool, ctx, userChatsParams{ + UserID: 1, + Limit: 10, + Before: time.Now().Add(time.Hour * 2).UTC(), + After: time.Now().Add(time.Hour * -2).UTC(), + }) + assert.NoError(t, err) + assert.Len(t, chats, 2) + + blastCount := 0 + for _, c := range chats { + if c.IsBlast { + blastCount++ + } + } + assert.Equal(t, "7eP5n:eYZmn", chats[1].ChatID) + assert.Equal(t, 1, blastCount) + } + + // user 100 (pre-existing) has a new message, but no new blasts + { + blasts, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 100, + }) + assert.NoError(t, err) + assert.Len(t, blasts, 0) + + messages = mustGetMessagesAndReactions(t, pool, ctx, 100, chatId_100_1) + assert.Len(t, messages, 2) + + messages = mustGetMessagesAndReactions(t, pool, ctx, 1, chatId_100_1) + assert.Len(t, messages, 2) + } + + // user 103 (pre-existing) has a new message, but no new blasts + { + blasts, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 103, + }) + assert.NoError(t, err) + assert.Len(t, blasts, 0) + + messages = mustGetMessagesAndReactions(t, pool, ctx, 103, chatId_1_103) + assert.Len(t, messages, 1) + + messages = mustGetMessagesAndReactions(t, pool, ctx, 1, chatId_1_103) + assert.Len(t, messages, 1) + } + + // user 101 has a blast + { + blasts, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 101, + }) + assert.NoError(t, err) + assert.Len(t, blasts, 1) + } + + // user 104 has zero blasts (inbox closed) + { + blasts, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 104, + }) + assert.NoError(t, err) + assert.Len(t, blasts, 0) + } + + // user 999 does not + { + assertChatCreateAllowed(t, ctx, validator, 999, 1, false) + + blasts, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 999, + }) + assert.NoError(t, err) + assert.Len(t, blasts, 0) + } + + // user 101 upgrades it to a real DM + { + + assertChatCreateAllowed(t, ctx, validator, 101, 1, true) + + err = chatCreate(pool, ctx, 101, t3, ChatCreateRPCParams{ + ChatID: chatId_101_1, + Invites: []PurpleInvite{ + {UserID: trashid.MustEncodeHashID(101), InviteCode: "earlier"}, + {UserID: trashid.MustEncodeHashID(1), InviteCode: "earlier"}, + }, + }) + assert.NoError(t, err) + + pool.QueryRow(ctx, `select count(*) from chat where chat_id = $1`, chatId_101_1).Scan(&count) + assert.Equal(t, 1, count) + + pool.QueryRow(ctx, `select count(*) from chat_member where chat_id = $1`, chatId_101_1).Scan(&count) + assert.Equal(t, 2, count) + + pool.QueryRow(ctx, `select count(*) from chat_member where is_hidden = false and chat_id = $1 and user_id = 101`, chatId_101_1).Scan(&count) + assert.Equal(t, 1, count) + + pool.QueryRow(ctx, `select count(*) from chat_member where is_hidden = true and chat_id = $1 and user_id = 1`, chatId_101_1).Scan(&count) + assert.Equal(t, 1, count) + + pool.QueryRow(ctx, `select count(*) from chat_message where chat_id = $1`, chatId_101_1).Scan(&count) + assert.Equal(t, 1, count) + + messages = mustGetMessagesAndReactions(t, pool, ctx, 101, chatId_101_1) + assert.Len(t, messages, 1) + } + + // after upgrade... user 101 has no pending blasts + { + blasts, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 101, + }) + assert.NoError(t, err) + assert.Len(t, blasts, 0) + } + + // after upgrade... user 101 has a chat + { + chats, err := getUserChats(pool, ctx, userChatsParams{ + UserID: 101, + Limit: 10, + Before: time.Now().Add(time.Hour * 12), + After: time.Now().Add(time.Hour * -12), + }) + assert.NoError(t, err) + assert.Len(t, chats, 1) + } + + // after upgrade... user 1 doesn't actually see the chat because it is hidden + { + chats, err := getUserChats(pool, ctx, userChatsParams{ + UserID: 1, + Limit: 50, + Before: time.Now().Add(time.Hour * 12), + After: time.Now().Add(time.Hour * -12), + }) + assert.NoError(t, err) + for _, chat := range chats { + if chat.ChatID == chatId_101_1 { + assert.Fail(t, "chat id should be hidden from user 1", chatId_101_1) + } + } + } + + // artist view: user 1 can get this blast + { + chat, err := getUserChat(pool, ctx, chatMembershipParams{ + UserID: 1, + ChatID: string(FollowerAudience), + }) + assert.NoError(t, err) + assert.Equal(t, string(FollowerAudience), chat.ChatID) + } + + // ----------------- a second message ------------------------ + + // Other user (104) re-opens inbox + err = chatSetPermissions(pool, ctx, 104, ChatPermissionAll, nil, nil, t3) + assert.NoError(t, err) + + outgoingMessages2, err := chatBlast(pool, ctx, 1, t4, ChatBlastRPCParams{ + BlastID: "b2", + Audience: FollowerAudience, + Message: "happy wed", + }) + assert.NoError(t, err) + + // Test that second blast also includes audience field + for _, outgoingMsg := range outgoingMessages2 { + assert.NotNil(t, outgoingMsg.ChatMessageRPC.Params.Audience, "Audience should be set in second blast outgoing message") + assert.Equal(t, FollowerAudience, *outgoingMsg.ChatMessageRPC.Params.Audience, "Audience should match the blast audience") + } + + pool.QueryRow(ctx, `select count(*) from chat_blast`).Scan(&count) + assert.Equal(t, 2, count) + + // user 101 above should have second blast added to the chat history... + { + chatId := trashid.ChatID(101, 1) + + pool.QueryRow(ctx, `select count(*) from chat_message where chat_id = $1`, chatId).Scan(&count) + assert.Equal(t, 2, count) + + messages = mustGetMessagesAndReactions(t, pool, ctx, 1, chatId) + assert.Len(t, messages, 2) + + assert.Equal(t, "happy wed", messages[0].Ciphertext) + assert.True(t, messages[0].IsPlaintext) + assert.Equal(t, "what up fam", messages[1].Ciphertext) + assert.True(t, messages[1].IsPlaintext) + assert.Greater(t, messages[0].CreatedAt, messages[1].CreatedAt) + + ch, err := getUserChat(pool, ctx, chatMembershipParams{ + UserID: 1, + ChatID: chatId, + }) + assert.NoError(t, err) + assert.True(t, ch.LastMessageIsPlaintext) + assert.Equal(t, "happy wed", ch.LastMessage.String) + + // user 101 reacts + { + heart := "heart" + chatReactMessage(pool, ctx, 101, chatId, messages[0].MessageID, &heart, t5) + + // reaction shows up + messages = mustGetMessagesAndReactions(t, pool, ctx, 1, chatId) + assert.Equal(t, "heart", messages[0].Reactions[0].Reaction) + } + + if false { + var debugRows []string + rows, err := pool.Query(ctx, `select row_to_json(c) from chat c;`) + assert.NoError(t, err) + defer rows.Close() + for rows.Next() { + var d string + err := rows.Scan(&debugRows) + assert.NoError(t, err) + fmt.Println("CHAT:", d) + } + } + + } + + // user 101 replies... now user 1 should see the thread + { + err = chatSendMessage(pool, ctx, 101, chatId_101_1, "respond_to_blast", t6, "101 responding to a blast from 1") + assert.NoError(t, err) + + chats, err := getUserChats(pool, ctx, userChatsParams{ + UserID: 1, + Limit: 50, + Before: time.Now().Add(time.Hour * 12), + After: time.Now().Add(time.Hour * -12), + }) + assert.NoError(t, err) + found := false + for _, chat := range chats { + if chat.ChatID == chatId_101_1 { + found = true + break + } + } + if !found { + assert.Fail(t, "chat id should now be visible to user 1", chatId_101_1) + } + } + + // user 104 should have just 1 blast + // since 104 opened inbox after first blast + { + blasts, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 104, + }) + assert.NoError(t, err) + assert.Len(t, blasts, 1) + + // 104 does upgrade + chatId_104_1 := trashid.ChatID(104, 1) + + err = chatCreate(pool, ctx, 104, t6, ChatCreateRPCParams{ + ChatID: chatId_104_1, + Invites: []PurpleInvite{ + {UserID: trashid.MustEncodeHashID(104), InviteCode: "earlier"}, + {UserID: trashid.MustEncodeHashID(1), InviteCode: "earlier"}, + }, + }) + assert.NoError(t, err) + + // 104 convo seeded with 1 message + + messages := mustGetMessagesAndReactions(t, pool, ctx, 104, chatId_104_1) + assert.Len(t, messages, 1) + messages = mustGetMessagesAndReactions(t, pool, ctx, 1, chatId_104_1) + assert.Len(t, messages, 1) + } + + // ------ sender can get blasts in a given thread ---------- + { + chat, err := getUserChat(pool, ctx, chatMembershipParams{ + UserID: 1, + ChatID: string(FollowerAudience), + }) + assert.NoError(t, err) + assert.Equal(t, string(FollowerAudience), chat.ChatID) + + messages, err := getChatMessagesAndReactions(pool, ctx, chatMessagesAndReactionsParams{ + UserID: 1, + ChatID: "follower_audience", + IsBlast: true, + Before: time.Now().Add(time.Hour * 2).UTC(), + After: time.Now().Add(time.Hour * -2).UTC(), + Limit: 10, + }) + assert.NoError(t, err) + assert.Len(t, messages, 2) + } + + // ------- bi-directional blasting works with upgrade -------- + + // 1 re-opens inbox + err = chatSetPermissions(pool, ctx, 1, ChatPermissionAll, nil, nil, t1) + assert.NoError(t, err) + + // 68 sends a blast + chatId_68_1 := trashid.ChatID(68, 1) + + _, err = chatBlast(pool, ctx, 68, t4, ChatBlastRPCParams{ + BlastID: "blast_from_68", + Audience: FollowerAudience, + Message: "I am 68", + }) + assert.NoError(t, err) + + // one side does upgrade + err = chatCreate(pool, ctx, 1, t5, ChatCreateRPCParams{ + ChatID: chatId_68_1, + Invites: []PurpleInvite{ + {UserID: trashid.MustEncodeHashID(68), InviteCode: "earlier"}, + {UserID: trashid.MustEncodeHashID(1), InviteCode: "earlier"}, + }, + }) + assert.NoError(t, err) + + // both parties should have 3 messages message + { + messages := mustGetMessagesAndReactions(t, pool, ctx, 68, chatId_68_1) + assert.Len(t, messages, 3) + } + + // both parties should have 3 messages message + { + messages := mustGetMessagesAndReactions(t, pool, ctx, 1, chatId_68_1) + assert.Len(t, messages, 3) + } +} + +func TestChatBlastTippers(t *testing.T) { + pool := database.CreateTestDatabase(t, "test_comms") + defer pool.Close() + database.Seed(pool, database.FixtureMap{ + "users": { + {"user_id": 1, "wallet": "wallet1", "handle": "user1"}, + {"user_id": 201, "wallet": "wallet201", "handle": "user201"}, + }, + "user_tips": { + { + "sender_user_id": 201, + "receiver_user_id": 1, + "amount": 1000, + "slot": 101, + "signature": "tip_sig_123", + }, + }, + }) + + ctx := context.Background() + + // 1 sends blast to supporters + tipperOutgoing, err := chatBlast(pool, ctx, 1, time.Now().UTC(), ChatBlastRPCParams{ + BlastID: "blast_tippers_1", + Audience: TipperAudience, + Message: "thanks for your support", + }) + assert.NoError(t, err) + + // Test that tipper blast includes correct audience field + for _, outgoingMsg := range tipperOutgoing { + assert.NotNil(t, outgoingMsg.ChatMessageRPC.Params.Audience, "Audience should be set in tipper blast outgoing message") + assert.Equal(t, TipperAudience, *outgoingMsg.ChatMessageRPC.Params.Audience, "Audience should match the tipper audience") + } + + // 201 should have a pending blast + { + pending, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 201, + }) + assert.NoError(t, err) + assert.Len(t, pending, 1) + } + + // 1 upgrades + chatId_1_201 := trashid.ChatID(1, 201) + err = chatCreate(pool, ctx, 101, time.Now().UTC(), ChatCreateRPCParams{ + ChatID: chatId_1_201, + Invites: []PurpleInvite{ + {UserID: trashid.MustEncodeHashID(1), InviteCode: "earlier"}, + {UserID: trashid.MustEncodeHashID(201), InviteCode: "earlier"}, + }, + }) + assert.NoError(t, err) + + // both users have 1 message + { + messages := mustGetMessagesAndReactions(t, pool, ctx, 1, chatId_1_201) + assert.Len(t, messages, 1) + } + { + messages := mustGetMessagesAndReactions(t, pool, ctx, 201, chatId_1_201) + assert.Len(t, messages, 1) + } + + // 201 should have no pending blast + { + pending, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 201, + }) + assert.NoError(t, err) + assert.Len(t, pending, 0) + } + + { + chat, err := getUserChat(pool, ctx, chatMembershipParams{ + UserID: 1, + ChatID: string(TipperAudience), + }) + assert.NoError(t, err) + assert.Equal(t, string(TipperAudience), chat.ChatID) + } +} + +func TestChatBlastRemixers(t *testing.T) { + trackContentType := AudienceContentType("track") + pool := database.CreateTestDatabase(t, "test_comms") + defer pool.Close() + database.Seed(pool, database.FixtureMap{ + "users": { + {"user_id": 1, "wallet": "wallet1", "handle": "user1"}, + {"user_id": 202, "wallet": "wallet202", "handle": "user202"}, + }, + "tracks": { + { + "track_id": 1, + "owner_id": 1, + }, + { + "track_id": 2, + "owner_id": 202, + }, + }, + "remixes": { + { + "parent_track_id": 1, + "child_track_id": 2, + }, + }, + }) + + ctx := context.Background() + + // 1 sends blast to remixers + remixerOutgoing, err := chatBlast(pool, ctx, 1, time.Now().UTC(), ChatBlastRPCParams{ + BlastID: "blast_remixers_1", + Audience: RemixerAudience, + AudienceContentType: &trackContentType, + AudienceContentID: stringPointer(trashid.MustEncodeHashID(1)), + Message: "thanks for your remix", + }) + assert.NoError(t, err) + + // Test that remixer blast includes correct audience field + for _, outgoingMsg := range remixerOutgoing { + assert.NotNil(t, outgoingMsg.ChatMessageRPC.Params.Audience, "Audience should be set in remixer blast outgoing message") + assert.Equal(t, RemixerAudience, *outgoingMsg.ChatMessageRPC.Params.Audience, "Audience should match the remixer audience") + } + + { + pending, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 202, + }) + assert.NoError(t, err) + assert.Len(t, pending, 1) + } + + // 1 sends another blast to all remixers + _, err = chatBlast(pool, ctx, 1, time.Now().UTC(), ChatBlastRPCParams{ + BlastID: "blast_remixers_2", + Audience: RemixerAudience, + Message: "new stems coming soon", + }) + assert.NoError(t, err) + + { + pending, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 202, + }) + assert.NoError(t, err) + assert.Len(t, pending, 2) + } + + // 202 upgrades... should have 2 messages + chatId_202_1 := trashid.ChatID(202, 1) + err = chatCreate(pool, ctx, 202, time.Now().UTC(), ChatCreateRPCParams{ + ChatID: chatId_202_1, + Invites: []PurpleInvite{ + {UserID: trashid.MustEncodeHashID(202), InviteCode: "earlier"}, + {UserID: trashid.MustEncodeHashID(1), InviteCode: "earlier"}, + }, + }) + assert.NoError(t, err) + + // both users have 2 messages + { + messages := mustGetMessagesAndReactions(t, pool, ctx, 202, chatId_202_1) + assert.Len(t, messages, 2) + } + { + messages := mustGetMessagesAndReactions(t, pool, ctx, 1, chatId_202_1) + assert.Len(t, messages, 2) + } + + _, err = chatBlast(pool, ctx, 1, time.Now().UTC(), ChatBlastRPCParams{ + BlastID: "blast_remixers_3", + Audience: RemixerAudience, + AudienceContentType: &trackContentType, + AudienceContentID: stringPointer(trashid.MustEncodeHashID(1)), + Message: "yall are the best", + }) + assert.NoError(t, err) + + // both users have 3 messages + { + messages := mustGetMessagesAndReactions(t, pool, ctx, 202, chatId_202_1) + assert.Len(t, messages, 3) + } + { + messages := mustGetMessagesAndReactions(t, pool, ctx, 1, chatId_202_1) + assert.Len(t, messages, 3) + } + + { + blastChatId := "remixer_audience:track:" + trashid.MustEncodeHashID(1) + chat, err := getUserChat(pool, ctx, chatMembershipParams{ + UserID: 1, + ChatID: blastChatId, + }) + assert.NoError(t, err) + assert.Equal(t, blastChatId, chat.ChatID) + } + + { + chat, err := getUserChat(pool, ctx, chatMembershipParams{ + UserID: 1, + ChatID: "remixer_audience", + }) + assert.NoError(t, err) + assert.Equal(t, "remixer_audience", chat.ChatID) + } + +} + +func TestChatBlastPurchasers(t *testing.T) { + pool := database.CreateTestDatabase(t, "test_comms") + defer pool.Close() + database.Seed(pool, database.FixtureMap{ + "users": { + {"user_id": 1, "wallet": "wallet1", "handle": "user1"}, + {"user_id": 203, "wallet": "wallet203", "handle": "user203"}, + }, + "tracks": { + { + "track_id": 1, + "owner_id": 1, + }, + }, + "usdc_purchases": { + { + "buyer_user_id": 203, + "seller_user_id": 1, + "content_type": "track", + "content_id": 1, + "amount": 5990000, // 5.99USDC in micro-units + "signature": "purchase_sig_123", + "slot": 101, + }, + }, + }) + + ctx := context.Background() + + _, err := chatBlast(pool, ctx, 1, time.Now().UTC(), ChatBlastRPCParams{ + BlastID: "blast_customers_1", + Audience: CustomerAudience, + Message: "thank you for yr purchase", + }) + assert.NoError(t, err) + + { + pending, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 203, + }) + assert.NoError(t, err) + assert.Len(t, pending, 1) + } + + { + chat, err := getUserChat(pool, ctx, chatMembershipParams{ + UserID: 1, + ChatID: "customer_audience", + }) + assert.NoError(t, err) + assert.Equal(t, "customer_audience", chat.ChatID) + } + + // no blasts for a specific track customer yet... so this is a not found error + { + _, err := getUserChat(pool, ctx, chatMembershipParams{ + UserID: 1, + ChatID: "customer_audience:track:1", + }) + assert.Error(t, err) + } +} + +func TestChatBlastCoinHolders(t *testing.T) { + pool := database.CreateTestDatabase(t, "test_comms") + defer pool.Close() + database.Seed(pool, database.FixtureMap{ + "users": { + {"user_id": 1, "wallet": "wallet1", "handle": "user1"}, + {"user_id": 204, "wallet": "wallet204", "handle": "user204"}, + {"user_id": 205, "wallet": "wallet205", "handle": "user205"}, + {"user_id": 206, "wallet": "wallet206", "handle": "user206"}, + }, + "artist_coins": { + { + "user_id": 1, + "ticker": "$ARTIST1", + "mint": "mint123", + "decimals": 8, + }, + }, + "sol_claimable_accounts": { + { + "signature": "sig1", + "account": "account204", + "ethereum_address": "wallet204", + "mint": "mint123", + }, + { + "signature": "sig2", + "account": "account205", + "ethereum_address": "wallet205", + "mint": "mint123", + }, + { + "signature": "sig3", + "account": "account206", + "ethereum_address": "wallet206", + "mint": "mint123", + }, + }, + }) + + ctx := context.Background() + + _, err := pool.Exec(ctx, ` + insert into sol_token_account_balance_changes + (signature, mint, owner, account, change, balance, slot, created_at, block_timestamp) + values + -- user 204: positive balance before blast + ('tx1', 'mint123', 'wallet204', 'account204', 1000, 1000, 10001, $1, $1), + ('tx2', 'mint123', 'wallet206', 'account206', 500, 500, 10003, $1, $1); + `, time.Now().UTC()) + assert.NoError(t, err) + + _, err = pool.Exec(ctx, ` + insert into sol_token_account_balance_changes + (signature, mint, owner, account, change, balance, slot, created_at, block_timestamp) + values + -- user 206: had positive balance, then sold to zero before blast + ('tx3', 'mint123', 'wallet206', 'account206', -500, 0, 10004, $1, $1); + `, time.Now().UTC()) + assert.NoError(t, err) + + // 1 sends blast to coin holders + _, err = chatBlast(pool, ctx, 1, time.Now().UTC(), ChatBlastRPCParams{ + BlastID: "blast_coin_holders_1", + Audience: CoinHolderAudience, + Message: "thanks for holding my coin", + }) + assert.NoError(t, err) + + // Only user 204 should have a pending blast (has positive balance) + { + pending, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 204, + }) + assert.NoError(t, err) + assert.Len(t, pending, 1) + } + + // User 205 should have no pending blast (zero balance) + { + pending, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 205, + }) + assert.NoError(t, err) + assert.Len(t, pending, 0) + } + + // User 206 should have no pending blast (sold before blast) + { + pending, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 206, + }) + assert.NoError(t, err) + assert.Len(t, pending, 0) + } + + // 204 upgrades to real DM + chatId_204_1 := trashid.ChatID(204, 1) + err = chatCreate(pool, ctx, 204, time.Now().UTC(), ChatCreateRPCParams{ + ChatID: chatId_204_1, + Invites: []PurpleInvite{ + {UserID: trashid.MustEncodeHashID(204), InviteCode: "earlier"}, + {UserID: trashid.MustEncodeHashID(1), InviteCode: "earlier"}, + }, + }) + assert.NoError(t, err) + + // Both users should have 1 message + { + messages := mustGetMessagesAndReactions(t, pool, ctx, 204, chatId_204_1) + assert.Len(t, messages, 1) + assert.Equal(t, "thanks for holding my coin", messages[0].Ciphertext) + } + { + messages := mustGetMessagesAndReactions(t, pool, ctx, 1, chatId_204_1) + assert.Len(t, messages, 1) + } + + // 204 should have no pending blast after upgrade + { + pending, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 204, + }) + assert.NoError(t, err) + assert.Len(t, pending, 0) + } + + // Test that new balance changes after blast don't affect existing blast + _, err = pool.Exec(ctx, ` + insert into sol_token_account_balance_changes + (signature, mint, owner, account, change, balance, slot, created_at, block_timestamp) + values + -- user 205 gets tokens AFTER the blast + ('tx5', 'mint123', 'wallet205', 'account205', 2000, 2000, 10005, $1, $1); + `, time.Now().UTC()) + assert.NoError(t, err) + + // User 205 still should have no pending blast (balance change was after blast) + { + pending, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 205, + }) + assert.NoError(t, err) + assert.Len(t, pending, 0) + } + + // Send another blast - now 205 should be included + _, err = chatBlast(pool, ctx, 1, time.Now().UTC(), ChatBlastRPCParams{ + BlastID: "blast_coin_holders_2", + Audience: CoinHolderAudience, + Message: "welcome new holders", + }) + assert.NoError(t, err) + + // Now user 205 should have a pending blast + { + pending, err := getNewBlasts(pool, ctx, getNewBlastsParams{ + UserID: 205, + }) + assert.NoError(t, err) + assert.Len(t, pending, 1) + } + + // User 204 should have the new blast added to existing chat + { + messages := mustGetMessagesAndReactions(t, pool, ctx, 204, chatId_204_1) + assert.Len(t, messages, 2) + assert.Equal(t, "welcome new holders", messages[0].Ciphertext) + assert.Equal(t, "thanks for holding my coin", messages[1].Ciphertext) + } + + // Test blast chat view for sender + { + chat, err := getUserChat(pool, ctx, chatMembershipParams{ + UserID: 1, + ChatID: "coin_holder_audience", + }) + assert.NoError(t, err) + assert.Equal(t, "coin_holder_audience", chat.ChatID) + } +} + +func stringPointer(val string) *string { + return &val +} diff --git a/api/comms/chat_block_test.go b/api/comms/chat_block_test.go index 51e46522..5c3e5b4b 100644 --- a/api/comms/chat_block_test.go +++ b/api/comms/chat_block_test.go @@ -11,11 +11,9 @@ import ( "bridgerton.audius.co/database" "bridgerton.audius.co/trashid" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestChatBlocking(t *testing.T) { - // Create test database pool := database.CreateTestDatabase(t, "test_comms") defer pool.Close() @@ -24,12 +22,6 @@ func TestChatBlocking(t *testing.T) { // Create validator for validation testing validator := CreateTestValidator(t, pool, DefaultRateLimitConfig, DefaultTestValidatorConfig) - // reset tables under test - _, err := pool.Exec(ctx, "truncate table chat_blocked_users cascade") - require.NoError(t, err) - _, err = pool.Exec(ctx, "truncate table chat cascade") - require.NoError(t, err) - seededRand := rand.New(rand.NewSource(time.Now().UnixNano())) user1Id := seededRand.Int31() user2Id := seededRand.Int31() @@ -37,7 +29,7 @@ func TestChatBlocking(t *testing.T) { assertBlocked := func(blockerUserId int32, blockeeUserId int32, timestamp time.Time, expected int) { row := pool.QueryRow(ctx, "select count(*) from chat_blocked_users where blocker_user_id = $1 and blockee_user_id = $2 and created_at = $3", blockerUserId, blockeeUserId, timestamp) var count int - err = row.Scan(&count) + err := row.Scan(&count) assert.NoError(t, err) assert.Equal(t, expected, count) } @@ -51,7 +43,7 @@ func TestChatBlocking(t *testing.T) { Params: []byte(fmt.Sprintf(`{"user_id": "%s"}`, encodedUserId)), } - err = validator.validateChatBlock(user1Id, exampleRpc) + err := validator.validateChatBlock(user1Id, exampleRpc) assert.NoError(t, err) } diff --git a/api/comms/chat_create_test.go b/api/comms/chat_create_test.go index 1e6ced44..367b7d31 100644 --- a/api/comms/chat_create_test.go +++ b/api/comms/chat_create_test.go @@ -11,7 +11,6 @@ import ( ) func TestChatCreate(t *testing.T) { - // Create test database pool := database.CreateTestDatabase(t, "test_comms") defer pool.Close() diff --git a/api/comms/chat_delete_test.go b/api/comms/chat_delete_test.go index 60c75765..73fac7d2 100644 --- a/api/comms/chat_delete_test.go +++ b/api/comms/chat_delete_test.go @@ -15,7 +15,6 @@ import ( ) func TestChatDeletion(t *testing.T) { - // Create test database pool := database.CreateTestDatabase(t, "test_comms") defer pool.Close() diff --git a/api/comms/chat_permissions_test.go b/api/comms/chat_permissions_test.go index 5a5b6f25..626b7dab 100644 --- a/api/comms/chat_permissions_test.go +++ b/api/comms/chat_permissions_test.go @@ -14,7 +14,6 @@ import ( ) func TestChatPermissions(t *testing.T) { - // Create test database pool := database.CreateTestDatabase(t, "test_comms") defer pool.Close() @@ -25,18 +24,28 @@ func TestChatPermissions(t *testing.T) { user2Id := seededRand.Int31() user3Id := seededRand.Int31() - // Set up test data - // user 1 follows user 2 - _, err := pool.Exec(ctx, "insert into follows (follower_user_id, followee_user_id, is_current, is_delete, created_at) values ($1, $2, true, false, now())", user1Id, user2Id) - require.NoError(t, err) - // user 3 has tipped user 1 - _, err = pool.Exec(ctx, ` - insert into user_tips - (slot, signature, sender_user_id, receiver_user_id, amount, created_at, updated_at) - values - (1, 'c', $1, $2, 100, now(), now()) - `, user3Id, user1Id) - require.NoError(t, err) + database.Seed(pool, database.FixtureMap{ + "users": { + {"user_id": user1Id, "wallet": "wallet1", "handle": "user1"}, + {"user_id": user2Id, "wallet": "wallet2", "handle": "user2"}, + {"user_id": user3Id, "wallet": "wallet3", "handle": "user3"}, + }, + "follows": { + { + "follower_user_id": user1Id, + "followee_user_id": user2Id, + }, + }, + "user_tips": { + { + "slot": 101, + "signature": "c", + "amount": 100, + "sender_user_id": user3Id, + "receiver_user_id": user1Id, + }, + }, + }) // Create validator for validation testing validator := CreateTestValidator(t, pool, DefaultRateLimitConfig, DefaultTestValidatorConfig) diff --git a/api/comms/chat_test.go b/api/comms/chat_test.go index d014cfaa..8282563c 100644 --- a/api/comms/chat_test.go +++ b/api/comms/chat_test.go @@ -15,7 +15,6 @@ import ( ) func TestChat(t *testing.T) { - // Create test database pool := database.CreateTestDatabase(t, "test_comms") defer pool.Close() diff --git a/api/comms/chat_test_queries.go b/api/comms/chat_test_queries.go new file mode 100644 index 00000000..86cf8c1e --- /dev/null +++ b/api/comms/chat_test_queries.go @@ -0,0 +1,338 @@ +package comms + +import ( + "context" + "encoding/json" + "errors" + "strings" + "time" + + "bridgerton.audius.co/api/dbv1" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" +) + +type chatMessagesAndReactionsParams struct { + UserID int32 `db:"user_id" json:"user_id"` + ChatID string `db:"chat_id" json:"chat_id"` + Limit int32 `json:"limit"` + Before time.Time `json:"before"` + After time.Time `json:"after"` + IsBlast bool `json:"is_blast"` +} +type chatMessageAndReactionsRow struct { + MessageID string `db:"message_id" json:"message_id"` + ChatID string `db:"chat_id" json:"chat_id"` + UserID int32 `db:"user_id" json:"user_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + Ciphertext string `db:"ciphertext" json:"ciphertext"` + IsPlaintext bool `db:"is_plaintext" json:"is_plaintext"` + Reactions reactions `json:"reactions"` +} +type chatMessageReactionRow struct { + UserID int32 `db:"user_id" json:"user_id"` + MessageID string `db:"message_id" json:"message_id"` + Reaction string `db:"reaction" json:"reaction"` + CreatedAt JSONTime `db:"created_at" json:"created_at"` + UpdatedAt JSONTime `db:"updated_at" json:"updated_at"` +} + +type JSONTime struct { + time.Time +} + +// Override JSONB timestamp unmarshaling since the postgres driver +// does not convert timestamp strings in JSON -> time.Time +func (t *JSONTime) UnmarshalJSON(b []byte) error { + timeformat := "2006-01-02T15:04:05.999999" + var timestamp string + err := json.Unmarshal(b, ×tamp) + if err != nil { + return err + } + t.Time, err = time.Parse(timeformat, timestamp) + if err != nil { + return err + } + return nil +} + +type reactions []chatMessageReactionRow + +func (reactions *reactions) Scan(value interface{}) error { + if value == nil { + *reactions = nil + return nil + } + + switch v := value.(type) { + case []byte: + return json.Unmarshal(v, reactions) + case string: + return json.Unmarshal([]byte(v), reactions) + default: + return errors.New("type assertion failed: expected []byte or string for JSON scanning") + } +} + +func getChatMessagesAndReactions(db dbv1.DBTX, ctx context.Context, arg chatMessagesAndReactionsParams) ([]chatMessageAndReactionsRow, error) { + // special case to handle outgoing blasts... + if arg.IsBlast { + parts := strings.Split(arg.ChatID, ":") + if len(parts) < 1 { + return nil, errors.New("bad request: invalid blast id") + } + audience := parts[0] + + if ChatBlastAudience(audience) == FollowerAudience || + ChatBlastAudience(audience) == TipperAudience || + ChatBlastAudience(audience) == CustomerAudience || + ChatBlastAudience(audience) == RemixerAudience { + + result, err := db.Query(ctx, ` + SELECT + b.blast_id as message_id, + @chat_id as chat_id, + b.from_user_id as user_id, + b.created_at, + b.plaintext as ciphertext, + true as is_plaintext, + '[]'::json AS reactions + FROM chat_blast b + WHERE b.from_user_id = @user_id + AND concat_ws(':', audience, audience_content_type, + CASE + WHEN audience_content_id IS NOT NULL THEN id_encode(audience_content_id) + ELSE NULL + END) = @chat_id + AND b.created_at < @before + AND b.created_at > @after + ORDER BY b.created_at DESC + LIMIT @limit + `, + pgx.NamedArgs{ + "user_id": arg.UserID, + "chat_id": arg.ChatID, + "before": arg.Before, + "after": arg.After, + "limit": arg.Limit, + }, + ) + if err != nil { + return nil, err + } + + return pgx.CollectRows(result, pgx.RowToStructByName[chatMessageAndReactionsRow]) + } else { + return nil, errors.New("bad request: unsupported audience " + audience) + } + } + + result, err := db.Query(ctx, ` + SELECT + chat_message.message_id, + chat_message.chat_id, + chat_message.user_id, + chat_message.created_at, + COALESCE(chat_message.ciphertext, chat_blast.plaintext) as ciphertext, + chat_blast.plaintext is not null as is_plaintext, + to_json(array(select row_to_json(r) from chat_message_reactions r where chat_message.message_id = r.message_id)) AS reactions + FROM chat_message + JOIN chat_member ON chat_message.chat_id = chat_member.chat_id + LEFT JOIN chat_blast USING (blast_id) + WHERE chat_member.user_id = @user_id + AND chat_message.chat_id = @chat_id + AND chat_message.created_at < @before + AND chat_message.created_at > @after + AND (chat_member.cleared_history_at IS NULL + OR chat_message.created_at > chat_member.cleared_history_at + ) + ORDER BY chat_message.created_at DESC, chat_message.message_id + LIMIT @limit`, + pgx.NamedArgs{ + "user_id": arg.UserID, + "chat_id": arg.ChatID, + "before": arg.Before, + "after": arg.After, + "limit": arg.Limit, + }, + ) + if err != nil { + return nil, err + } + + return pgx.CollectRows(result, pgx.RowToStructByName[chatMessageAndReactionsRow]) +} + +type chatMembershipParams struct { + UserID int32 `db:"user_id" json:"user_id"` + ChatID string `db:"chat_id" json:"chat_id"` +} + +type userChatRow struct { + ChatID string `db:"chat_id" json:"chat_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + LastMessage pgtype.Text `db:"last_message" json:"last_message"` + LastMessageAt time.Time `db:"last_message_at" json:"last_message_at"` + LastMessageIsPlaintext bool `db:"last_message_is_plaintext" json:"last_message_is_plaintext"` + InviteCode string `db:"invite_code" json:"invite_code"` + LastActiveAt pgtype.Timestamp `db:"last_active_at" json:"last_active_at"` + UnreadCount int32 `db:"unread_count" json:"unread_count"` + ClearedHistoryAt pgtype.Timestamp `db:"cleared_history_at" json:"cleared_history_at"` + IsBlast bool `db:"is_blast" json:"is_blast"` + Audience pgtype.Text `db:"audience" json:"audience"` + AudienceContentType pgtype.Text `db:"audience_content_type" json:"audience_content_type"` + AudienceContentID pgtype.Int4 `db:"audience_content_id" json:"audience_content_id"` +} + +const userChatQuery = ` +SELECT + chat.chat_id, + chat.created_at, + chat.last_message, + chat.last_message_at, + chat.last_message_is_plaintext, + chat_member.invite_code, + chat_member.last_active_at, + chat_member.unread_count, + chat_member.cleared_history_at, + false as is_blast, + null as audience, + null as audience_content_type, + null as audience_content_id +FROM chat_member +JOIN chat ON chat.chat_id = chat_member.chat_id +WHERE chat_member.user_id = @user_id AND chat_member.chat_id = @chat_id + +union all ( + + SELECT DISTINCT ON (audience, audience_content_type, audience_content_id) + concat_ws(':', audience, audience_content_type, + CASE + WHEN audience_content_id IS NOT NULL THEN id_encode(audience_content_id) + ELSE NULL + END) as chat_id, + min(created_at) over (partition by audience, audience_content_type, audience_content_id) as created_at, + plaintext as last_message, + max(created_at) over (partition by audience, audience_content_type, audience_content_id) as last_message_at, + true as last_message_is_plaintext, + '' as invite_code, + created_at as last_active_at, + 0 as unread_count, + null as cleared_history_at, + true as is_blast, + audience, + audience_content_type, + audience_content_id + FROM chat_blast b + WHERE from_user_id = @user_id + AND concat_ws(':', audience, audience_content_type, + CASE + WHEN audience_content_id IS NOT NULL THEN id_encode(audience_content_id) + ELSE NULL + END) = @chat_id + ORDER BY + audience, + audience_content_type, + audience_content_id, + created_at DESC +) +` + +func getUserChat(db dbv1.DBTX, ctx context.Context, arg chatMembershipParams) (userChatRow, error) { + rows, err := db.Query(ctx, userChatQuery, pgx.NamedArgs{ + "user_id": arg.UserID, + "chat_id": arg.ChatID, + }) + if err != nil { + return userChatRow{}, err + } + defer rows.Close() + + row, err := pgx.CollectExactlyOneRow(rows, pgx.RowToStructByName[userChatRow]) + if err != nil { + return userChatRow{}, err + } + return row, nil +} + +type userChatsParams struct { + UserID int32 `db:"user_id" json:"user_id"` + Limit int32 `json:"limit"` + Before time.Time `json:"before"` + After time.Time `json:"after"` +} + +const userChatsQuery = ` +SELECT + chat.chat_id, + chat.created_at, + chat.last_message, + chat.last_message_at, + chat.last_message_is_plaintext, + chat_member.invite_code, + chat_member.last_active_at, + chat_member.unread_count, + chat_member.cleared_history_at, + false as is_blast, + null as audience, + null as audience_content_type, + null as audience_content_id +FROM chat_member +JOIN chat ON chat.chat_id = chat_member.chat_id +WHERE chat_member.user_id = @user_id + AND chat_member.is_hidden = false + AND chat.last_message IS NOT NULL + AND chat.last_message_at < @before + AND chat.last_message_at > @after + AND (chat_member.cleared_history_at IS NULL + OR chat.last_message_at > chat_member.cleared_history_at) + + +union all ( + + SELECT DISTINCT ON (audience, audience_content_type, audience_content_id) + concat_ws(':', audience, audience_content_type, + CASE + WHEN audience_content_id IS NOT NULL THEN id_encode(audience_content_id) + ELSE NULL + END) as chat_id, + min(created_at) over (partition by audience, audience_content_type, audience_content_id) as created_at, + plaintext as last_message, + max(created_at) over (partition by audience, audience_content_type, audience_content_id) as last_message_at, + true as last_message_is_plaintext, + '' as invite_code, + created_at as last_active_at, + 0 as unread_count, + null as cleared_history_at, + true as is_blast, + audience, + audience_content_type, + audience_content_id + FROM chat_blast b + WHERE from_user_id = @user_id + AND b.created_at < @before + AND b.created_at > @after + ORDER BY + audience, + audience_content_type, + audience_content_id, + created_at DESC +) + +ORDER BY last_message_at DESC, is_blast DESC, chat_id ASC +LIMIT @limit +` + +func getUserChats(db dbv1.DBTX, ctx context.Context, arg userChatsParams) ([]userChatRow, error) { + rows, err := db.Query(ctx, userChatsQuery, pgx.NamedArgs{ + "user_id": arg.UserID, + "limit": arg.Limit, + "before": arg.Before, + "after": arg.After, + }) + if err != nil { + return nil, err + } + return pgx.CollectRows(rows, pgx.RowToStructByName[userChatRow]) +} diff --git a/api/comms/rate_limit_test.go b/api/comms/rate_limit_test.go index 419f5ea7..311f0a55 100644 --- a/api/comms/rate_limit_test.go +++ b/api/comms/rate_limit_test.go @@ -9,7 +9,6 @@ import ( "bridgerton.audius.co/database" "bridgerton.audius.co/trashid" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestBurstRateLimit(t *testing.T) { @@ -19,10 +18,6 @@ func TestBurstRateLimit(t *testing.T) { ctx := context.Background() - // reset tables under test - _, err := pool.Exec(ctx, "truncate table chat cascade") - require.NoError(t, err) - chatId := trashid.ChatID(1, 2) // Use deterministic chat ID user1Id := int32(1) user2Id := int32(2) @@ -38,7 +33,7 @@ func TestBurstRateLimit(t *testing.T) { // hit the 1 second limit... send a burst of messages for i := 1; i < 5; i++ { message := fmt.Sprintf("burst %d", i) - err = chatSendMessage(pool, ctx, user1Id, chatId, message, time.Now().UTC(), message) + err := chatSendMessage(pool, ctx, user1Id, chatId, message, time.Now().UTC(), message) assert.NoError(t, err, "i is", i) messageRpc := RawRPC{ diff --git a/api/comms_mutate_test.go b/api/comms_mutate_test.go index 4f91abed..b7576469 100644 --- a/api/comms_mutate_test.go +++ b/api/comms_mutate_test.go @@ -2,6 +2,7 @@ package api import ( "encoding/json" + "fmt" "strings" "testing" "time" @@ -51,7 +52,6 @@ func postMutateRPCData(t *testing.T, app *ApiServer, currentUserID string, metho // protocol repo) in the comms package func TestPostMutateChat(t *testing.T) { testWallet1 := testdata.CreateTestWallet(t, user1WalletKey) - testWallet2 := testdata.CreateTestWallet(t, user2WalletKey) app := emptyTestApp(t) // Setup test data @@ -69,7 +69,7 @@ func TestPostMutateChat(t *testing.T) { { "user_id": 2, "handle": "user2", - "wallet": strings.ToLower(testWallet2.Address), + "wallet": "0x7d273271690538cf855e5b3002a0dd8c154bb060", "created_at": now.Add(-time.Hour), "updated_at": now.Add(-time.Hour), "is_current": true, @@ -83,8 +83,9 @@ func TestPostMutateChat(t *testing.T) { var user2EncodedID = trashid.MustEncodeHashID(2) t.Run("valid create, skip dupes", func(t *testing.T) { + chatId := trashid.ChatID(1, 2) params := comms.ChatCreateRPCParams{ - ChatID: trashid.ChatID(1, 2), + ChatID: chatId, Invites: []comms.PurpleInvite{ { UserID: user1EncodedID, @@ -98,10 +99,18 @@ func TestPostMutateChat(t *testing.T) { } { - // Test getting regular chat messages (not blasts) status, _ := postMutateRPCData(t, app, user1EncodedID, comms.RPCMethodChatCreate, params, now.UnixMilli(), testWallet1) assert.Equal(t, 200, status) - // TODO: Fetch and check it + + url := fmt.Sprintf("/comms/chats/%s", chatId) + + status, body := testGetWithWallet(t, app, url, "0x7d273271690538cf855e5b3002a0dd8c154bb060") + assert.Equal(t, 200, status) + jsonAssert(t, body, map[string]any{ + "data.invite_code": "test", + "data.chat_members.0.user_id": user1EncodedID, + "data.chat_members.1.user_id": user2EncodedID, + }) } { @@ -111,9 +120,3 @@ func TestPostMutateChat(t *testing.T) { } }) } - -/* TODO: -- 403 when attestation fails -- 400 when we can't get user id from wallet -- 400 when readSignedPost fails -*/ diff --git a/api/dbv1/chat_messages_row.go b/api/dbv1/chat_messages_row.go index e34d6e21..78f75e64 100644 --- a/api/dbv1/chat_messages_row.go +++ b/api/dbv1/chat_messages_row.go @@ -34,12 +34,19 @@ type JSONTime struct { type Reactions []ChatMessageReactionRow func (reactions *Reactions) Scan(value interface{}) error { - bytes, ok := value.([]byte) - if !ok { - return errors.New("type assertion to []byte failed") + if value == nil { + *reactions = nil + return nil } - return json.Unmarshal(bytes, reactions) + switch v := value.(type) { + case []byte: + return json.Unmarshal(v, reactions) + case string: + return json.Unmarshal([]byte(v), reactions) + default: + return errors.New("type assertion failed: expected []byte or string for JSON scanning") + } } // Override JSONB timestamp unmarshaling since the postgres driver diff --git a/solana/indexer/db_insert_test.go b/solana/indexer/db_insert_test.go index 09d2943b..c41fa9ff 100644 --- a/solana/indexer/db_insert_test.go +++ b/solana/indexer/db_insert_test.go @@ -332,9 +332,9 @@ func TestInsertBalanceChangeTriggers(t *testing.T) { // Now associate the wallet and verify the user balance is updated _, err = pool.Exec(t.Context(), - `INSERT INTO associated_wallets + `INSERT INTO associated_wallets (id, user_id, wallet, chain, blockhash, blocknumber, is_current, is_delete) - VALUES + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) `, 3, 1, "owner3", "sol", "blockhash3", 101, true, false,