From de50816edccbd3db988d7874883a16cd902029c0 Mon Sep 17 00:00:00 2001 From: kompotkot Date: Wed, 21 Feb 2024 08:00:10 +0000 Subject: [PATCH] Fix cors and starknet sepolia support --- nodebalancer/cmd/nodebalancer/balancer.go | 4 +-- nodebalancer/cmd/nodebalancer/configs.go | 4 +++ nodebalancer/cmd/nodebalancer/middleware.go | 32 ++++++++++++++------- nodebalancer/cmd/nodebalancer/server.go | 2 +- 4 files changed, 29 insertions(+), 13 deletions(-) diff --git a/nodebalancer/cmd/nodebalancer/balancer.go b/nodebalancer/cmd/nodebalancer/balancer.go index 1496b049..558ee593 100644 --- a/nodebalancer/cmd/nodebalancer/balancer.go +++ b/nodebalancer/cmd/nodebalancer/balancer.go @@ -198,7 +198,7 @@ func (bpool *BlockchainPool) HealthCheck() { for _, b := range bpool.Blockchains { var timeout time.Duration getLatestBlockReq := `{"jsonrpc":"2.0","method":"eth_getBlockByNumber","params":["latest", false],"id":1}` - if b.Blockchain == "starknet" || b.Blockchain == "starknet-goerli" { + if b.Blockchain == "starknet" || b.Blockchain == "starknet-goerli" || b.Blockchain == "starknet-sepolia" { getLatestBlockReq = `{"jsonrpc":"2.0","method":"starknet_getBlockWithTxHashes","params":["latest"],"id":"0"}` timeout = NB_HEALTH_CHECK_CALL_TIMEOUT * 2 } @@ -241,7 +241,7 @@ func (bpool *BlockchainPool) HealthCheck() { } var blockNumber uint64 - if b.Blockchain == "starknet" || b.Blockchain == "starknet-goerli" { + if b.Blockchain == "starknet" || b.Blockchain == "starknet-goerli" || b.Blockchain == "starknet-sepolia" { blockNumber = statusResponse.Result.BlockNumber } else { blockNumberHex := strings.Replace(statusResponse.Result.Number, "0x", "", -1) diff --git a/nodebalancer/cmd/nodebalancer/configs.go b/nodebalancer/cmd/nodebalancer/configs.go index e606dc83..d3991328 100644 --- a/nodebalancer/cmd/nodebalancer/configs.go +++ b/nodebalancer/cmd/nodebalancer/configs.go @@ -35,6 +35,7 @@ var ( NB_CONTROLLER_TOKEN = os.Getenv("NB_CONTROLLER_TOKEN") NB_CONTROLLER_ACCESS_ID = os.Getenv("NB_CONTROLLER_ACCESS_ID") MOONSTREAM_CORS_ALLOWED_ORIGINS = os.Getenv("MOONSTREAM_CORS_ALLOWED_ORIGINS") + CORS_WHITELIST_MAP = make(map[string]bool) NB_CONNECTION_RETRIES = 2 NB_CONNECTION_RETRIES_INTERVAL = time.Millisecond * 10 @@ -86,6 +87,9 @@ func CheckEnvVarSet() { NB_CONTROLLER_ACCESS_ID = uuid.New().String() log.Printf("Access ID for internal usage in NB_CONTROLLER_ACCESS_ID environment variable is not valid uuid, generated random one: %v", NB_CONTROLLER_ACCESS_ID) } + for _, o := range strings.Split(MOONSTREAM_CORS_ALLOWED_ORIGINS, ",") { + CORS_WHITELIST_MAP[o] = true + } } // Nodes configuration diff --git a/nodebalancer/cmd/nodebalancer/middleware.go b/nodebalancer/cmd/nodebalancer/middleware.go index e99ec037..64a0f363 100644 --- a/nodebalancer/cmd/nodebalancer/middleware.go +++ b/nodebalancer/cmd/nodebalancer/middleware.go @@ -368,19 +368,31 @@ func panicMiddleware(next http.Handler) http.Handler { // CORS middleware func corsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodOptions { - for _, allowedOrigin := range strings.Split(MOONSTREAM_CORS_ALLOWED_ORIGINS, ",") { - if r.Header.Get("Origin") == allowedOrigin { - w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) - w.Header().Set("Access-Control-Allow-Methods", "GET,POST") - // Credentials are cookies, authorization headers, or TLS client certificates - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") - } + var allowedOrigin string + if CORS_WHITELIST_MAP["*"] { + allowedOrigin = "*" + } else { + origin := r.Header.Get("Origin") + if _, ok := CORS_WHITELIST_MAP[origin]; ok { + allowedOrigin = origin } - w.WriteHeader(http.StatusNoContent) + } + + fmt.Println(allowedOrigin, CORS_WHITELIST_MAP) + + if allowedOrigin != "" { + w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) + w.Header().Set("Access-Control-Allow-Methods", "GET,POST,OPTIONS") + // Credentials are cookies, authorization headers, or TLS client certificates + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") + } + + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) return } + next.ServeHTTP(w, r) }) } diff --git a/nodebalancer/cmd/nodebalancer/server.go b/nodebalancer/cmd/nodebalancer/server.go index e1cbe155..e8a0cdc1 100644 --- a/nodebalancer/cmd/nodebalancer/server.go +++ b/nodebalancer/cmd/nodebalancer/server.go @@ -31,7 +31,7 @@ var ( func initHealthCheck(debug bool) { healthCheckInterval, convErr := strconv.Atoi(NB_HEALTH_CHECK_INTERVAL) if convErr != nil { - healthCheckInterval = 5 + healthCheckInterval = 30 } t := time.NewTicker(time.Second * time.Duration(healthCheckInterval)) for {