Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix issue 2485 which occur oom when using async servlet request. #3440

Merged
merged 4 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
import com.alibaba.csp.sentinel.slots.block.BlockException;
import com.alibaba.csp.sentinel.util.AssertUtil;
import com.alibaba.csp.sentinel.util.StringUtil;

import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.AsyncHandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;

import javax.servlet.http.HttpServletRequest;
Expand All @@ -50,11 +52,11 @@
* return mav;
* }
* </pre>
*
*
* @author kaizi2009
* @since 1.7.1
*/
public abstract class AbstractSentinelInterceptor implements HandlerInterceptor {
public abstract class AbstractSentinelInterceptor implements AsyncHandlerInterceptor {

public static final String SENTINEL_SPRING_WEB_CONTEXT_NAME = "sentinel_spring_web_context";
private static final String EMPTY_ORIGIN = "";
Expand All @@ -66,12 +68,12 @@ public AbstractSentinelInterceptor(BaseWebMvcConfig config) {
AssertUtil.assertNotBlank(config.getRequestAttributeName(), "requestAttributeName should not be blank");
this.baseWebMvcConfig = config;
}

/**
* @param request
* @param rcKey
* @param step
* @return reference count after increasing (initial value as zero to be increased)
* @return reference count after increasing (initial value as zero to be increased)
*/
private Integer increaseReference(HttpServletRequest request, String rcKey, int step) {
Object obj = request.getAttribute(rcKey);
Expand All @@ -85,10 +87,10 @@ private Integer increaseReference(HttpServletRequest request, String rcKey, int
request.setAttribute(rcKey, newRc);
return newRc;
}

@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
throws Exception {
throws Exception {
try {
String resourceName = getResourceName(request);

Expand All @@ -99,7 +101,7 @@ public boolean preHandle(HttpServletRequest request, HttpServletResponse respons
if (increaseReference(request, this.baseWebMvcConfig.getRequestRefName(), 1) != 1) {
return true;
}

// Parse the request origin using registered origin parser.
String origin = parseOrigin(request);
String contextName = getContextName(request);
Expand Down Expand Up @@ -135,21 +137,45 @@ protected String getContextName(HttpServletRequest request) {
return SENTINEL_SPRING_WEB_CONTEXT_NAME;
}


/**
* When a handler starts an asynchronous request, the DispatcherServlet exits without invoking postHandle and afterCompletion
* Called instead of postHandle and afterCompletion to exit the context and clean thread-local variables when the handler is being executed concurrently.
*
* @param request the current request
* @param response the current response
* @param handler the handler (or {@link HandlerMethod}) that started async
* execution, for type and/or instance examination
*/
@Override
public void afterConcurrentHandlingStarted(HttpServletRequest request, HttpServletResponse response,
Object handler) throws Exception {
exit(request);
}

@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response,
Object handler, Exception ex) throws Exception {
exit(request, ex);
}

private void exit(HttpServletRequest request) {
exit(request, null);
}

private void exit(HttpServletRequest request, Exception ex) {
if (increaseReference(request, this.baseWebMvcConfig.getRequestRefName(), -1) != 0) {
return;
}

Entry entry = getEntryInRequest(request, baseWebMvcConfig.getRequestAttributeName());
if (entry == null) {
// should not happen
RecordLog.warn("[{}] No entry found in request, key: {}",
getClass().getSimpleName(), baseWebMvcConfig.getRequestAttributeName());
return;
}

traceExceptionAndExit(entry, ex);
removeEntryInRequest(request);
ContextUtil.exit();
Expand All @@ -162,7 +188,7 @@ public void postHandle(HttpServletRequest request, HttpServletResponse response,

protected Entry getEntryInRequest(HttpServletRequest request, String attrKey) {
Object entryObject = request.getAttribute(attrKey);
return entryObject == null ? null : (Entry)entryObject;
return entryObject == null ? null : (Entry) entryObject;
}

protected void removeEntryInRequest(HttpServletRequest request) {
Expand All @@ -188,7 +214,7 @@ && increaseReference(request, this.baseWebMvcConfig.getRequestRefName() + ":" +
}

protected void handleBlockException(HttpServletRequest request, HttpServletResponse response, BlockException e)
throws Exception {
throws Exception {
if (baseWebMvcConfig.getBlockExceptionHandler() != null) {
baseWebMvcConfig.getBlockExceptionHandler().handle(request, response, e);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

import com.alibaba.csp.sentinel.context.ContextUtil;
import com.alibaba.csp.sentinel.node.ClusterNode;
import com.alibaba.csp.sentinel.slots.block.RuleConstant;
import com.alibaba.csp.sentinel.slots.block.degrade.DegradeRule;
Expand Down Expand Up @@ -66,6 +68,18 @@ public void testBase() throws Exception {
assertEquals(1, cn.passQps(), 0.01);
}

@Test
public void testAsync() throws Exception {
String url = "/async";
this.mvc.perform(get(url))
.andExpect(status().isOk());

ClusterNode cn = ClusterBuilderSlot.getClusterNode(url);
assertNotNull(cn);
assertEquals(1, cn.passQps(), 0.01);
assertNull(ContextUtil.getContext());
}

@Test
public void testOriginParser() throws Exception {
String springMvcPathVariableUrl = "/foo/{id}";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import com.alibaba.csp.sentinel.adapter.spring.webmvc.exception.BizException;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.context.request.async.DeferredResult;

/**
* @author kaizi2009
Expand Down Expand Up @@ -58,4 +60,16 @@ public String apiExclude(@PathVariable("id") Long id) {
return "Exclude " + id;
}

@GetMapping("/async")
@ResponseBody
public DeferredResult<String> distribute() throws Exception{
DeferredResult<String> result = new DeferredResult<>();

Thread thread = new Thread(() -> result.setResult("async result."));
thread.start();

Thread.yield();
return result;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
*/
package com.alibaba.csp.sentinel.adapter.spring.webmvc_v6x;

import com.alibaba.csp.sentinel.*;
import com.alibaba.csp.sentinel.Entry;
import com.alibaba.csp.sentinel.EntryType;
import com.alibaba.csp.sentinel.ResourceTypeConstants;
import com.alibaba.csp.sentinel.SphU;
import com.alibaba.csp.sentinel.Tracer;
import com.alibaba.csp.sentinel.adapter.spring.webmvc_v6x.config.BaseWebMvcConfig;
import com.alibaba.csp.sentinel.context.ContextUtil;
import com.alibaba.csp.sentinel.log.RecordLog;
Expand All @@ -24,7 +28,8 @@
import com.alibaba.csp.sentinel.util.StringUtil;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.AsyncHandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;

/**
Expand All @@ -45,7 +50,7 @@
*
* @since 1.8.8
*/
public abstract class AbstractSentinelInterceptor implements HandlerInterceptor {
public abstract class AbstractSentinelInterceptor implements AsyncHandlerInterceptor {

public static final String SENTINEL_SPRING_WEB_CONTEXT_NAME = "sentinel_spring_web_context";
private static final String EMPTY_ORIGIN = "";
Expand Down Expand Up @@ -124,9 +129,33 @@ protected String getContextName(HttpServletRequest request) {
return SENTINEL_SPRING_WEB_CONTEXT_NAME;
}


/**
* When a handler starts an asynchronous request, the DispatcherServlet exits without invoking postHandle and afterCompletion
* Called instead of postHandle and afterCompletion to exit the context and clean thread-local variables when the handler is being executed concurrently.
*
* @param request the current request
* @param response the current response
* @param handler the handler (or {@link HandlerMethod}) that started async
* execution, for type and/or instance examination
*/
@Override
public void afterConcurrentHandlingStarted(HttpServletRequest request, HttpServletResponse response,
Object handler) throws Exception {
exit(request);
}

@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response,
Object handler, Exception ex) throws Exception {
exit(request, ex);
}

private void exit(HttpServletRequest request) {
exit(request, null);
}

private void exit(HttpServletRequest request, Exception ex) {
if (increaseReference(request, this.baseWebMvcConfig.getRequestRefName(), -1) != 0) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

import com.alibaba.csp.sentinel.context.ContextUtil;
import com.alibaba.csp.sentinel.node.ClusterNode;
import com.alibaba.csp.sentinel.slots.block.RuleConstant;
import com.alibaba.csp.sentinel.slots.block.flow.FlowRule;
Expand Down Expand Up @@ -64,6 +66,18 @@ public void testBase() throws Exception {
assertEquals(1, cn.passQps(), 0.01);
}

@Test
public void testAsync() throws Exception {
String url = "/async";
this.mvc.perform(get(url))
.andExpect(status().isOk());

ClusterNode cn = ClusterBuilderSlot.getClusterNode(url);
assertNotNull(cn);
assertEquals(1, cn.passQps(), 0.01);
assertNull(ContextUtil.getContext());
}

@Test
public void testOriginParser() throws Exception {
String springMvcPathVariableUrl = "/foo/{id}";
Expand All @@ -78,7 +92,7 @@ public void testOriginParser() throws Exception {

// This will be blocked since the caller is same: userA
this.mvc.perform(
get("/foo/2").accept(MediaType.APPLICATION_JSON).header(headerName, limitOrigin))
get("/foo/2").accept(MediaType.APPLICATION_JSON).header(headerName, limitOrigin))
.andExpect(status().isOk())
.andExpect(content().json(ResultWrapper.blocked().toJsonString()));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.context.request.async.DeferredResult;

/**
* @author kaizi2009
Expand Down Expand Up @@ -52,4 +54,16 @@ public String apiExclude(@PathVariable("id") Long id) {
return "Exclude " + id;
}

@GetMapping("/async")
@ResponseBody
public DeferredResult<String> distribute() throws Exception {
DeferredResult<String> result = new DeferredResult<>();

Thread thread = new Thread(() -> result.setResult("async result."));
thread.start();

Thread.yield();
return result;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@

import java.util.Random;
import java.util.concurrent.TimeUnit;

import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.servlet.ModelAndView;

/**
* Test controller
*
* @author kaizi2009
*/
@Controller
Expand Down Expand Up @@ -57,14 +60,25 @@ public String apiExclude(@PathVariable("id") Long id) {
doBusiness();
return "Exclude " + id;
}

@GetMapping("/forward")
public ModelAndView apiForward() {
ModelAndView mav = new ModelAndView();
mav.setViewName("hello");
return mav;
}

@GetMapping("/async")
@ResponseBody
public DeferredResult<String> distribute() throws Exception {
DeferredResult<String> result = new DeferredResult<>(4000L);

Thread thread = new Thread(() -> result.setResult("async result"));
thread.start();

return result;
}

private void doBusiness() {
Random random = new Random(1);
try {
Expand Down
Loading