Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,11 @@ public enum LogKeys implements LogKey {
HOST_LOCAL_BLOCKS_SIZE,
HOST_PORT,
HOST_PORT2,
HTTP_METHOD,
HTTP_QUERY_STRING,
HTTP_REFERER,
HTTP_STATUS_CODE,
HTTP_USER_AGENT,
HUGE_METHOD_LIMIT,
HYBRID_STORE_DISK_BACKEND,
IDENTIFIER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.util.zip.ZipOutputStream
import scala.util.control.NonFatal
import scala.xml.Node

import jakarta.servlet.Filter
import jakarta.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
import org.eclipse.jetty.ee10.servlet.{ServletContextHandler, ServletHolder}

Expand Down Expand Up @@ -74,6 +75,14 @@ class HistoryServer(
// and its metrics, for testing as well as monitoring
val cacheMetrics = appCache.metrics

override protected def getInternalFilters: Seq[() => Filter] = {
if (conf.get(History.HISTORY_SERVER_UI_ACCESS_LOG_ENABLED)) {
Seq(() => new HistoryServerAccessLogFilter(conf))
} else {
Nil
}
}

private val loaderServlet = new HttpServlet {
protected override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.deploy.history

import java.net.URLDecoder
import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.TimeUnit

import scala.util.control.NonFatal

import jakarta.servlet.{Filter, FilterChain, ServletRequest, ServletResponse}
import jakarta.servlet.http.{HttpServletRequest, HttpServletResponse, HttpServletResponseWrapper}

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys._
import org.apache.spark.internal.config.{History, SECRET_REDACTION_PATTERN}
import org.apache.spark.util.Utils

private[history] class HistoryServerAccessLogFilter(conf: SparkConf)
extends Filter with Logging {

import HistoryServerAccessLogFilter._

private val redactionPattern = Some(conf.get(SECRET_REDACTION_PATTERN))

private val excludedPathPrefixes = conf.get(History.HISTORY_SERVER_UI_ACCESS_LOG_EXCLUDE_PATHS)
.map(normalizePathPrefix)
.filter(_.nonEmpty)

override def doFilter(
request: ServletRequest,
response: ServletResponse,
chain: FilterChain): Unit = {
(request, response) match {
case (httpRequest: HttpServletRequest, httpResponse: HttpServletResponse)
if !shouldSkip(httpRequest) =>
doFilterWithAccessLog(httpRequest, httpResponse, chain)

case _ =>
chain.doFilter(request, response)
}
}

private def doFilterWithAccessLog(
request: HttpServletRequest,
response: HttpServletResponse,
chain: FilterChain): Unit = {
val responseWrapper = new StatusCaptureResponse(response)
val startNs = System.nanoTime()
var error: Throwable = null
try {
chain.doFilter(request, responseWrapper)
} catch {
case NonFatal(e) =>
error = e
throw e
} finally {
logAccess(request, responseWrapper, startNs, Option(error))
}
}

private def shouldSkip(request: HttpServletRequest): Boolean = {
val requestPath = Option(request.getRequestURI).getOrElse("")
excludedPathPrefixes.exists(matchesPathPrefix(requestPath, _))
}

private def logAccess(
request: HttpServletRequest,
response: StatusCaptureResponse,
startNs: Long,
error: Option[Throwable]): Unit = {
val durationMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs)
val status = if (error.isDefined && response.status == HttpServletResponse.SC_OK) {
HttpServletResponse.SC_INTERNAL_SERVER_ERROR
} else {
response.status
}
val errorClass = error.map(_.getClass.getName).getOrElse(MissingField)

logInfo(log"Spark History Server access" +
log" method=${MDC(HTTP_METHOD, field(request.getMethod))}" +
log" uri=${MDC(URI, field(redact(request.getRequestURI)))}" +
log" query=${MDC(HTTP_QUERY_STRING, redactQueryString(request.getQueryString))}" +
log" status=${MDC(HTTP_STATUS_CODE, status)}" +
log" durationMs=${MDC(DURATION, durationMs)}" +
log" remoteAddress=${MDC(REMOTE_ADDRESS, field(request.getRemoteAddr))}" +
log" user=${MDC(USER_NAME, field(remoteUser(request)))}" +
log" userAgent=${MDC(HTTP_USER_AGENT, field(redact(request.getHeader("User-Agent"))))}" +
log" referer=${MDC(HTTP_REFERER, field(redact(request.getHeader("Referer"))))}" +
log" error=${MDC(ERROR, errorClass)}")
}

private def remoteUser(request: HttpServletRequest): String = {
Option(request.getRemoteUser)
.orElse(Option(request.getUserPrincipal).map(_.getName))
.orNull
}

private def redact(value: String): String = {
Option(value).map(Utils.redact(redactionPattern, _)).orNull
}

private def redactQueryString(queryString: String): String = {
Option(queryString)
.filter(_.nonEmpty)
.map(_.split("&", -1).map(redactQueryParam).mkString("&"))
.getOrElse(MissingField)
}

private def redactQueryParam(param: String): String = {
val separator = param.indexOf('=')
if (separator < 0) {
val decodedParam = decodeQueryComponent(param)
val redactedParam = Utils.redact(redactionPattern, decodedParam)
if (redactedParam != decodedParam) {
Utils.REDACTION_REPLACEMENT_TEXT
} else {
field(param)
}
} else {
val rawKey = param.substring(0, separator)
val rawValue = param.substring(separator + 1)
val decodedKey = decodeQueryComponent(rawKey)
val decodedValue = decodeQueryComponent(rawValue)
val redactedValue = Utils.redact(redactionPattern, Seq(decodedKey -> decodedValue)).head._2
val valueToLog = if (redactedValue == Utils.REDACTION_REPLACEMENT_TEXT) {
Utils.REDACTION_REPLACEMENT_TEXT
} else {
rawValue
}
s"${field(rawKey)}=${field(valueToLog)}"
}
}
}

private[history] object HistoryServerAccessLogFilter {

private val MissingField = "-"

private def normalizePathPrefix(path: String): String = {
val trimmed = path.trim
if (trimmed.isEmpty) {
""
} else {
val withLeadingSlash = if (trimmed.startsWith("/")) trimmed else s"/$trimmed"
if (withLeadingSlash == "/") withLeadingSlash else withLeadingSlash.stripSuffix("/")
}
}

private def matchesPathPrefix(requestPath: String, prefix: String): Boolean = {
prefix == "/" || requestPath == prefix || requestPath.startsWith(s"$prefix/")
}

private def decodeQueryComponent(value: String): String = {
try {
URLDecoder.decode(value, UTF_8.name())
} catch {
case NonFatal(_) => value
}
}

private def field(value: String): String = {
Option(value)
.filter(_.nonEmpty)
.map { v =>
v.map {
case c if Character.isWhitespace(c) || Character.isISOControl(c) => '_'
case c => c
}.mkString
}
.getOrElse(MissingField)
}

private class StatusCaptureResponse(response: HttpServletResponse)
extends HttpServletResponseWrapper(response) {

private var _status = HttpServletResponse.SC_OK

def status: Int = _status

override def setStatus(sc: Int): Unit = {
_status = sc
super.setStatus(sc)
}

override def sendError(sc: Int): Unit = {
_status = sc
super.sendError(sc)
}

override def sendError(sc: Int, msg: String): Unit = {
_status = sc
super.sendError(sc, msg)
}

override def sendRedirect(location: String): Unit = {
_status = HttpServletResponse.SC_FOUND
super.sendRedirect(location)
}
}
}
21 changes: 21 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/History.scala
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,27 @@ private[spark] object History {
.intConf
.createWithDefault(18080)

val HISTORY_SERVER_UI_ACCESS_LOG_ENABLED =
ConfigBuilder("spark.history.ui.accessLog.enabled")
.doc("Whether the History Server should log HTTP access records for its web UI and REST " +
"API. When enabled, each non-excluded request is logged at INFO level after it " +
"completes. Query string values are redacted using spark.redaction.regex.")
.version("4.3.0")
.withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE)
.booleanConf
.createWithDefault(false)

val HISTORY_SERVER_UI_ACCESS_LOG_EXCLUDE_PATHS =
ConfigBuilder("spark.history.ui.accessLog.excludePaths")
.doc("Comma-separated list of request path prefixes to exclude from History Server " +
"HTTP access logs. This can be used to avoid logging high-volume low-value requests " +
"such as static resources.")
.version("4.3.0")
.withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE)
.stringConf
.toSequence
.createWithDefault(Seq("/static", "/favicon.ico"))

val FAST_IN_PROGRESS_PARSING =
ConfigBuilder("spark.history.fs.inProgressOptimization.enabled")
.doc("Enable optimized handling of in-progress logs. This option may leave finished " +
Expand Down
23 changes: 21 additions & 2 deletions core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,17 @@ private[spark] object JettyUtils extends Logging {
conf: SparkConf,
serverName: String = "",
poolSize: Int = 200): ServerInfo = {
startJettyServer(hostName, port, sslOptions, conf, serverName, poolSize, Nil)
}

def startJettyServer(
hostName: String,
port: Int,
sslOptions: SSLOptions,
conf: SparkConf,
serverName: String,
poolSize: Int,
internalFilters: Seq[() => Filter]): ServerInfo = {

val stopTimeout = conf.get(UI_JETTY_STOP_TIMEOUT)
logInfo(log"Start Jetty ${MDC(HOST, hostName)}:${MDC(PORT, port)}" +
Expand Down Expand Up @@ -381,7 +392,7 @@ private[spark] object JettyUtils extends Logging {

server.addConnector(httpConnector)
pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads))
ServerInfo(server, httpPort, securePort, conf, collection)
ServerInfo(server, httpPort, securePort, conf, collection, internalFilters)
} catch {
case e: Exception =>
server.stop()
Expand Down Expand Up @@ -469,7 +480,8 @@ private[spark] case class ServerInfo(
boundPort: Int,
securePort: Option[Int],
private val conf: SparkConf,
private val rootHandler: ContextHandlerCollection) extends Logging {
private val rootHandler: ContextHandlerCollection,
private val internalFilters: Seq[() => Filter] = Nil) extends Logging {

def addHandler(
handler: ServletContextHandler,
Expand Down Expand Up @@ -547,6 +559,13 @@ private[spark] case class ServerInfo(
JettyUtils.addFilter(handler, filter, oldParams ++ newParams)
}

// Internal filters run after user-installed filters so authentication wrappers are visible,
// and before the security filter so denied requests can still be observed.
internalFilters.foreach { filter =>
val holder = new FilterHolder(filter())
handler.addFilter(holder, "/*", EnumSet.of(DispatcherType.REQUEST))
}

// This filter must come after user-installed filters, since that's where authentication
// filters are installed. This means that custom filters will see the request before it's
// been validated by the security filter.
Expand Down
7 changes: 5 additions & 2 deletions core/src/main/scala/org/apache/spark/ui/WebUI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.xml.Node

import jakarta.servlet.DispatcherType
import jakarta.servlet.{DispatcherType, Filter}
import jakarta.servlet.http.{HttpServlet, HttpServletRequest}
import org.eclipse.jetty.ee10.servlet.{FilterHolder, FilterMapping, ServletContextHandler, ServletHolder}
import org.json4s.JsonAST.{JNothing, JValue}
Expand Down Expand Up @@ -137,13 +137,16 @@ private[spark] abstract class WebUI(
attachHandler(JettyUtils.createStaticHandler(resourceBase, path))
}

protected def getInternalFilters: Seq[() => Filter] = Nil

/** A hook to initialize components of the UI */
def initialize(): Unit

def initServer(): ServerInfo = {
val hostName = Option(conf.getenv("SPARK_LOCAL_IP"))
.getOrElse(if (Utils.preferIPv6) "[::]" else "0.0.0.0")
val server = startJettyServer(hostName, port, sslOptions, conf, name, poolSize)
val server = startJettyServer(
hostName, port, sslOptions, conf, name, poolSize, getInternalFilters)
server
}

Expand Down
Loading