Skip to content

Commit

Permalink
Fix issue 185 - request cannot rewind during retry (#186)
Browse files Browse the repository at this point in the history
* Ensure the proxied request is type that is rewindable

* Add comment on why we copy request body

* Fix test by dealing with nil cases
  • Loading branch information
alvinlin123 committed Jan 26, 2024
1 parent 5770ebf commit 64c8f0b
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions handler/proxy_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package handler
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httputil"
Expand Down Expand Up @@ -119,6 +120,14 @@ func chunked(transferEncoding []string) bool {
return false
}

func readDownStreamRequestBody(req *http.Request) ([]byte, error) {
if req.Body == nil {
return []byte{}, nil
}
defer req.Body.Close()
return io.ReadAll(req.Body)
}

func (p *ProxyClient) Do(req *http.Request) (*http.Response, error) {
proxyURL := *req.URL
if p.HostOverride != "" {
Expand All @@ -140,7 +149,16 @@ func (p *ProxyClient) Do(req *http.Request) (*http.Response, error) {
log.WithField("request", string(initialReqDump)).Debug("Initial request dump:")
}

proxyReq, err := http.NewRequest(req.Method, proxyURL.String(), req.Body)
// Save the request body into memory so that it's rewindable during retry.
// See https://github.com/awslabs/aws-sigv4-proxy/issues/185
// This may increase memory demand, but the demand should be ok for most cases. If there
// are cases proven to be very problematic, we can consider adding a flag to disable this.
proxyReqBody, err := readDownStreamRequestBody(req)
if err != nil {
return nil, err
}

proxyReq, err := http.NewRequest(req.Method, proxyURL.String(), bytes.NewReader(proxyReqBody))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -222,15 +240,15 @@ func (p *ProxyClient) Do(req *http.Request) (*http.Response, error) {
}

if (p.LogFailedRequest || log.GetLevel() == log.DebugLevel) && resp.StatusCode >= 400 {
b, _ := ioutil.ReadAll(resp.Body)
b, _ := io.ReadAll(resp.Body)
log.WithField("request", fmt.Sprintf("%s %s", proxyReq.Method, proxyReq.URL)).
WithField("status_code", resp.StatusCode).
WithField("message", string(b)).
Error("error proxying request")

// Need to "reset" the response body because we consumed the stream above, otherwise caller will
// get empty body.
resp.Body = ioutil.NopCloser(bytes.NewBuffer(b))
resp.Body = io.NopCloser(bytes.NewBuffer(b))
}

return resp, nil
Expand Down

0 comments on commit 64c8f0b

Please sign in to comment.