Skip to content

Commit 517d6f9

Browse files
committed
feat: apply source filtering for msearch
1 parent dab7b87 commit 517d6f9

1 file changed

Lines changed: 49 additions & 15 deletions

File tree

plugins/elasticsearch/middleware.go

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io/ioutil"
88
"log"
99
"net/http"
10+
"strings"
1011

1112
"github.com/appbaseio/arc/middleware"
1213
"github.com/appbaseio/arc/middleware/classify"
@@ -130,12 +131,12 @@ func transformRequest(h http.HandlerFunc) http.HandlerFunc {
130131
}
131132
// transform POST request(search) to GET
132133
if *reqACL == category.Search {
133-
req.Method = http.MethodGet
134+
isMsearch := strings.HasSuffix(req.URL.String(), "/_msearch")
134135
// Apply source filters
135136
reqPermission, err := permission.FromContext(ctx)
136137
if err != nil {
137138
log.Printf("%s: %v\n", logTag, err)
138-
util.WriteBackError(w, err.Error(), http.StatusInternalServerError)
139+
h(w, req)
139140
return
140141
}
141142
sources := make(map[string]interface{})
@@ -148,22 +149,55 @@ func transformRequest(h http.HandlerFunc) http.HandlerFunc {
148149
if len(Excludes) > 0 {
149150
sources["excludes"] = Excludes
150151
}
151-
body, err := ioutil.ReadAll(req.Body)
152-
if err != nil {
153-
log.Printf("%s: %v\n", logTag, err)
154-
util.WriteBackError(w, err.Error(), http.StatusInternalServerError)
155-
return
156-
}
157-
d := json.NewDecoder(ioutil.NopCloser(bytes.NewReader(body)))
158-
reqBody := make(map[string]interface{})
159-
d.Decode(&reqBody)
160152
_, isExcludesPresent := sources["excludes"]
161153
isDefaultInclude := len(Includes) > 0 && Includes[0] == "*"
162-
if !isDefaultInclude || isExcludesPresent {
163-
reqBody["_source"] = sources
154+
shouldApplyFilters := !isDefaultInclude || isExcludesPresent
155+
if shouldApplyFilters {
156+
if isMsearch {
157+
// Handle the _msearch requests
158+
body, err := ioutil.ReadAll(req.Body)
159+
if err != nil {
160+
log.Printf("%s: %v\n", logTag, err)
161+
util.WriteBackError(w, err.Error(), http.StatusInternalServerError)
162+
return
163+
}
164+
var reqBodyString = string(body)
165+
splitReq := strings.Split(reqBodyString, "\n")
166+
var modifiedBodyString string
167+
for index, element := range splitReq {
168+
if index%2 == 1 { // even lines
169+
var reqBody = make(map[string]interface{})
170+
err := json.Unmarshal([]byte(element), &reqBody)
171+
if err != nil {
172+
log.Printf("%s: %v\n", logTag, err)
173+
util.WriteBackError(w, err.Error(), http.StatusInternalServerError)
174+
return
175+
}
176+
reqBody["_source"] = sources
177+
raw, _ := json.Marshal(reqBody)
178+
modifiedBodyString += string(raw)
179+
} else {
180+
modifiedBodyString += element
181+
}
182+
modifiedBodyString += "\n"
183+
}
184+
modifiedBody := []byte(modifiedBodyString)
185+
req.Body = ioutil.NopCloser(bytes.NewReader(modifiedBody))
186+
} else {
187+
body, err := ioutil.ReadAll(req.Body)
188+
if err != nil {
189+
log.Printf("%s: %v\n", logTag, err)
190+
util.WriteBackError(w, err.Error(), http.StatusInternalServerError)
191+
return
192+
}
193+
d := json.NewDecoder(ioutil.NopCloser(bytes.NewReader(body)))
194+
reqBody := make(map[string]interface{})
195+
d.Decode(&reqBody)
196+
reqBody["_source"] = sources
197+
modifiedBody, _ := json.Marshal(reqBody)
198+
req.Body = ioutil.NopCloser(bytes.NewReader(modifiedBody))
199+
}
164200
}
165-
modifiedBody, _ := json.Marshal(reqBody)
166-
req.Body = ioutil.NopCloser(bytes.NewReader(modifiedBody))
167201
}
168202
h(w, req)
169203
}

0 commit comments

Comments
 (0)