diff --git a/nodes/node_balancer/cmd/nodebalancer/middleware.go b/nodes/node_balancer/cmd/nodebalancer/middleware.go index 7f9e4a0e..81811a5e 100644 --- a/nodes/node_balancer/cmd/nodebalancer/middleware.go +++ b/nodes/node_balancer/cmd/nodebalancer/middleware.go @@ -167,6 +167,31 @@ func panicMiddleware(next http.Handler) http.Handler { }) } +// Split JSON RPC request to object and slice and return slice of requests +func jsonrpcRequestParser(body []byte) ([]JSONRPCRequest, error) { + var jsonrpcRequest []JSONRPCRequest + + firstByte := bytes.TrimLeft(body, " \t\r\n") + switch { + case len(firstByte) > 0 && firstByte[0] == '[': + err := json.Unmarshal(body, &jsonrpcRequest) + if err != nil { + return nil, fmt.Errorf("Unable to parse body, err: %v", err) + } + case len(firstByte) > 0 && firstByte[0] == '{': + var singleJsonrpcRequest JSONRPCRequest + err := json.Unmarshal(body, &singleJsonrpcRequest) + if err != nil { + return nil, fmt.Errorf("Unable to parse body, err: %v", err) + } + jsonrpcRequest = []JSONRPCRequest{singleJsonrpcRequest} + default: + return nil, fmt.Errorf("Incorrect first byte in JSON RPC request") + } + + return jsonrpcRequest, nil +} + // Log access requests in proper format func logMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -198,12 +223,20 @@ func logMiddleware(next http.Handler) http.Handler { // Parse body and log method if jsonrpc path pathSlice := strings.Split(r.URL.Path, "/") if r.Method == "POST" && pathSlice[len(pathSlice)-1] == "jsonrpc" { - var jsonrpcRequest JSONRPCRequest - err = json.Unmarshal(body, &jsonrpcRequest) + jsonrpcRequests, err := jsonrpcRequestParser(body) if err != nil { - log.Printf("Unable to parse body at logging middleware, err: %v", err) + log.Println(err) + } + for i, jsonrpcRequest := range jsonrpcRequests { + if i == 0 { + logStr += fmt.Sprintf(" [%s", jsonrpcRequest.Method) + } else { + logStr += fmt.Sprintf(" %s", jsonrpcRequest.Method) + } + if i == len(jsonrpcRequests)-1 { + logStr += fmt.Sprint("]") + } } - logStr += fmt.Sprintf(" %s", jsonrpcRequest.Method) } if stateCLI.enableDebugFlag { diff --git a/nodes/node_balancer/cmd/nodebalancer/routes.go b/nodes/node_balancer/cmd/nodebalancer/routes.go index 0d390c39..31e376e2 100644 --- a/nodes/node_balancer/cmd/nodebalancer/routes.go +++ b/nodes/node_balancer/cmd/nodebalancer/routes.go @@ -97,9 +97,9 @@ func lbJSONRPCHandler(w http.ResponseWriter, r *http.Request, blockchain string, } r.Body = ioutil.NopCloser(bytes.NewBuffer(body)) - var jsonrpcRequest JSONRPCRequest - err = json.Unmarshal(body, &jsonrpcRequest) + jsonrpcRequests, err := jsonrpcRequestParser(body) if err != nil { + log.Println(err) http.Error(w, "Unable to parse JSON RPC request", http.StatusBadRequest) return } @@ -111,10 +111,12 @@ func lbJSONRPCHandler(w http.ResponseWriter, r *http.Request, blockchain string, return } if currentClientAccess.ExtendedMethods == false { - _, exists := ALLOWED_METHODS[jsonrpcRequest.Method] - if !exists { - http.Error(w, "Method for provided access id not allowed", http.StatusForbidden) - return + for _, jsonrpcRequest := range jsonrpcRequests { + _, exists := ALLOWED_METHODS[jsonrpcRequest.Method] + if !exists { + http.Error(w, "Method for provided access id not allowed", http.StatusForbidden) + return + } } }