// blockedOperators is the set of MongoDB operators rejected in passthrough
// ("querymongo") queries. Per the MongoDB manual, $where, $function and
// $accumulator execute caller-supplied JavaScript, and $expr admits
// aggregation expressions inside a find filter.
var blockedOperators = map[string]bool{
	"$where": true, "$expr": true, "$function": true, "$accumulator": true,
}

// maxQueryJSONSize is the maximum length, in bytes, of the raw querymongo
// JSON string accepted before it is parsed.
const maxQueryJSONSize = 16 * 1024

// sanitizeQuery walks a decoded query document and returns an error as soon
// as it finds a key listed in blockedOperators, at any nesting depth. Nested
// documents and arrays, including arrays nested inside other arrays, are
// searched. It returns nil when no blocked operator is present, which
// includes a nil or empty map.
//
// If several blocked operators are present, which one is named in the error
// depends on map iteration order.
func sanitizeQuery(m map[string]interface{}) error {
	for key, val := range m {
		if blockedOperators[key] {
			return fmt.Errorf("operator %s is not allowed in passthrough queries", key)
		}
		if err := sanitizeValue(val); err != nil {
			return err
		}
	}
	return nil
}

// sanitizeValue applies sanitizeQuery to every document reachable from val.
//
// Only the container types produced by encoding/json for interface{} targets
// are recognized (map[string]interface{} and []interface{}); values of any
// other type, including other named map types, are treated as leaves.
func sanitizeValue(val interface{}) error {
	switch v := val.(type) {
	case map[string]interface{}:
		return sanitizeQuery(v)
	case []interface{}:
		// Recurse per element so that documents hidden inside nested arrays
		// are inspected as well, not just direct document elements.
		for _, item := range v {
			if err := sanitizeValue(item); err != nil {
				return err
			}
		}
	}
	return nil
}

// Tests for sanitizeQuery (shock-server/controller/node/multi_test.go in the
// original patch).

// TestSanitizeQuery_AllowsSafeOperators checks that ordinary field matches and
// non-blocked operators pass validation.
func TestSanitizeQuery_AllowsSafeOperators(t *testing.T) {
	tests := []struct {
		name  string
		query map[string]interface{}
	}{
		{
			name:  "simple field match",
			query: map[string]interface{}{"file.name": "test.fasta"},
		},
		{
			name:  "$gt operator",
			query: map[string]interface{}{"file.size": map[string]interface{}{"$gt": 1000}},
		},
		{
			name:  "$in operator",
			query: map[string]interface{}{"file.name": map[string]interface{}{"$in": []interface{}{"a", "b"}}},
		},
		{
			name:  "$exists operator",
			query: map[string]interface{}{"attributes.project": map[string]interface{}{"$exists": true}},
		},
		{
			name: "$and with $elemMatch",
			query: map[string]interface{}{
				"$and": []interface{}{
					map[string]interface{}{"file.size": map[string]interface{}{"$gt": 0}},
					map[string]interface{}{"tags": map[string]interface{}{"$elemMatch": map[string]interface{}{"$eq": "metagenome"}}},
				},
			},
		},
		{
			name: "nested arrays of plain values",
			query: map[string]interface{}{
				"coords": map[string]interface{}{
					"$in": []interface{}{
						[]interface{}{1, 2},
						[]interface{}{"a", "b"},
					},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := sanitizeQuery(tt.query); err != nil {
				t.Errorf("sanitizeQuery() returned unexpected error: %v", err)
			}
		})
	}
}

// TestSanitizeQuery_BlocksDangerousOperators checks that every blocked
// operator is rejected, whether at the top level or nested.
func TestSanitizeQuery_BlocksDangerousOperators(t *testing.T) {
	tests := []struct {
		name   string
		query  map[string]interface{}
		wantOp string
	}{
		{
			name:   "$where at top level",
			query:  map[string]interface{}{"$where": "this.file.size > 1000"},
			wantOp: "$where",
		},
		{
			name:   "$expr at top level",
			query:  map[string]interface{}{"$expr": map[string]interface{}{"$gt": []interface{}{"$file.size", 1000}}},
			wantOp: "$expr",
		},
		{
			name:   "$function at top level",
			query:  map[string]interface{}{"$function": map[string]interface{}{"body": "function() { return true; }"}},
			wantOp: "$function",
		},
		{
			name:   "$accumulator at top level",
			query:  map[string]interface{}{"$accumulator": map[string]interface{}{"init": "function() {}"}},
			wantOp: "$accumulator",
		},
		{
			name: "$where nested inside $or",
			query: map[string]interface{}{
				"$or": []interface{}{
					map[string]interface{}{"file.name": "test"},
					map[string]interface{}{"$where": "this.file.size > 0"},
				},
			},
			wantOp: "$where",
		},
		{
			name: "$expr nested inside $and inside $or",
			query: map[string]interface{}{
				"$or": []interface{}{
					map[string]interface{}{
						"$and": []interface{}{
							map[string]interface{}{"$expr": map[string]interface{}{"$gt": []interface{}{"$a", "$b"}}},
						},
					},
				},
			},
			wantOp: "$expr",
		},
		{
			name: "$where hidden inside an array of arrays",
			query: map[string]interface{}{
				"$or": []interface{}{
					[]interface{}{
						map[string]interface{}{"$where": "this.file.size > 0"},
					},
				},
			},
			wantOp: "$where",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := sanitizeQuery(tt.query)
			if err == nil {
				t.Errorf("sanitizeQuery() expected error for operator %s, got nil", tt.wantOp)
				return
			}
			expected := "operator " + tt.wantOp + " is not allowed in passthrough queries"
			if err.Error() != expected {
				t.Errorf("sanitizeQuery() error = %q, want %q", err.Error(), expected)
			}
		})
	}
}