package com.imed.costaccount.common.xss; import cn.hutool.core.util.StrUtil; import cn.hutool.http.HtmlUtil; import cn.hutool.json.JSONUtil; import javax.servlet.ReadListener; import javax.servlet.ServletInputStream; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import java.io.*; import java.nio.charset.Charset; import java.util.LinkedHashMap; import java.util.Map; public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper { public XssHttpServletRequestWrapper(HttpServletRequest request) { super(request); } @Override public String getParameter(String name) { String value = super.getParameter(name); if (!StrUtil.hasEmpty(value)) { value = HtmlUtil.filter(value); } return value; } @Override public String[] getParameterValues(String name) { String[] values = super.getParameterValues(name); if (values != null) { for (int i = 0; i < values.length; i++) { String value = values[i]; if (!StrUtil.hasEmpty(value)) { value = HtmlUtil.filter(value); } values[i] = value; } } return values; } @Override public Map getParameterMap() { Map parameters = super.getParameterMap(); LinkedHashMap map = new LinkedHashMap(); if (parameters != null) { for (String key : parameters.keySet()) { String[] values = parameters.get(key); for (int i = 0; i < values.length; i++) { String value = values[i]; if (!StrUtil.hasEmpty(value)) { value = HtmlUtil.filter(value); } values[i] = value; } map.put(key, values); } } return map; } @Override public String getHeader(String name) { String value = super.getHeader(name); if (!StrUtil.hasEmpty(value)) { value = HtmlUtil.filter(value); } return value; } @Override public ServletInputStream getInputStream() throws IOException { InputStream in = super.getInputStream(); InputStreamReader reader = new InputStreamReader(in, Charset.forName("UTF-8")); BufferedReader buffer = new BufferedReader(reader); StringBuffer body = new StringBuffer(); String line = buffer.readLine(); while (line != null) { body.append(line); line = buffer.readLine(); } buffer.close(); reader.close(); in.close(); Map map = JSONUtil.parseObj(body.toString()); Map result = new LinkedHashMap<>(); for (String key : map.keySet()) { Object val = map.get(key); if (val instanceof String) { if (!StrUtil.hasEmpty(val.toString())) { result.put(key, HtmlUtil.filter(val.toString())); } } else { result.put(key, val); } } String json = JSONUtil.toJsonStr(result); ByteArrayInputStream bain = new ByteArrayInputStream(json.getBytes()); return new ServletInputStream() { @Override public int read() throws IOException { return bain.read(); } @Override public boolean isFinished() { return false; } @Override public boolean isReady() { return false; } @Override public void setReadListener(ReadListener readListener) { } }; } }