diff --git a/src/main/java/com/hubspot/jinjava/JinjavaConfig.java b/src/main/java/com/hubspot/jinjava/JinjavaConfig.java index 3cff4787c..343df63a8 100644 --- a/src/main/java/com/hubspot/jinjava/JinjavaConfig.java +++ b/src/main/java/com/hubspot/jinjava/JinjavaConfig.java @@ -88,6 +88,7 @@ public class JinjavaConfig { private final ExecutionMode executionMode; private final LegacyOverrides legacyOverrides; private final boolean enablePreciseDivideFilter; + private final boolean enableFilterChainOptimization; private final ObjectMapper objectMapper; private final Features features; @@ -151,6 +152,7 @@ private JinjavaConfig(Builder builder) { legacyOverrides = builder.legacyOverrides; dateTimeProvider = builder.dateTimeProvider; enablePreciseDivideFilter = builder.enablePreciseDivideFilter; + enableFilterChainOptimization = builder.enableFilterChainOptimization; objectMapper = setupObjectMapper(builder.objectMapper); objectUnwrapper = builder.objectUnwrapper; processors = builder.processors; @@ -307,6 +309,10 @@ public boolean getEnablePreciseDivideFilter() { return enablePreciseDivideFilter; } + public boolean isEnableFilterChainOptimization() { + return enableFilterChainOptimization; + } + public DateTimeProvider getDateTimeProvider() { return dateTimeProvider; } @@ -349,6 +355,7 @@ public static class Builder { private ExecutionMode executionMode = DefaultExecutionMode.instance(); private LegacyOverrides legacyOverrides = LegacyOverrides.NONE; private boolean enablePreciseDivideFilter = false; + private boolean enableFilterChainOptimization = false; private ObjectMapper objectMapper = null; private ObjectUnwrapper objectUnwrapper = new JinjavaObjectUnwrapper(); @@ -520,6 +527,13 @@ public Builder withEnablePreciseDivideFilter(boolean enablePreciseDivideFilter) return this; } + public Builder withEnableFilterChainOptimization( + boolean enableFilterChainOptimization + ) { + this.enableFilterChainOptimization = enableFilterChainOptimization; + return this; + } + public Builder withObjectMapper(ObjectMapper objectMapper) { this.objectMapper = objectMapper; return this; diff --git a/src/main/java/com/hubspot/jinjava/el/ext/AstFilterChain.java b/src/main/java/com/hubspot/jinjava/el/ext/AstFilterChain.java new file mode 100644 index 000000000..618b87905 --- /dev/null +++ b/src/main/java/com/hubspot/jinjava/el/ext/AstFilterChain.java @@ -0,0 +1,212 @@ +package com.hubspot.jinjava.el.ext; + +import com.hubspot.jinjava.interpret.DisabledException; +import com.hubspot.jinjava.interpret.JinjavaInterpreter; +import com.hubspot.jinjava.interpret.TemplateError; +import com.hubspot.jinjava.interpret.TemplateError.ErrorItem; +import com.hubspot.jinjava.interpret.TemplateError.ErrorReason; +import com.hubspot.jinjava.interpret.TemplateError.ErrorType; +import com.hubspot.jinjava.lib.filter.Filter; +import com.hubspot.jinjava.objects.SafeString; +import de.odysseus.el.tree.Bindings; +import de.odysseus.el.tree.impl.ast.AstNode; +import de.odysseus.el.tree.impl.ast.AstParameters; +import de.odysseus.el.tree.impl.ast.AstRightValue; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import javax.el.ELContext; +import javax.el.ELException; + +/** + * AST node for a chain of filters applied to an input expression. + * Instead of creating nested AstMethod calls for each filter in a chain like: + * filter:length.filter(filter:lower.filter(filter:trim.filter(input))) + * + * This node represents the entire chain as a single evaluation unit: + * input|trim|lower|length + * + * This optimization reduces: + * - Filter lookups (done once per filter instead of per AST node traversal) + * - Method invocation overhead + * - Object wrapping/unwrapping between filters + * - Context operations + */ +public class AstFilterChain extends AstRightValue { + + protected final AstNode input; + protected final List filterSpecs; + + public AstFilterChain(AstNode input, List filterSpecs) { + this.input = Objects.requireNonNull(input, "Input node cannot be null"); + this.filterSpecs = Objects.requireNonNull(filterSpecs, "Filter specs cannot be null"); + if (filterSpecs.isEmpty()) { + throw new IllegalArgumentException("Filter chain must have at least one filter"); + } + } + + public AstNode getInput() { + return input; + } + + public List getFilterSpecs() { + return filterSpecs; + } + + @Override + public Object eval(Bindings bindings, ELContext context) { + JinjavaInterpreter interpreter = getInterpreter(context); + + if (interpreter.getContext().isValidationMode()) { + return ""; + } + + Object value = input.eval(bindings, context); + + for (FilterSpec spec : filterSpecs) { + String filterKey = ExtendedParser.FILTER_PREFIX + spec.getName(); + interpreter.getContext().addResolvedValue(filterKey); + + Filter filter; + try { + filter = interpreter.getContext().getFilter(spec.getName()); + } catch (DisabledException e) { + interpreter.addError( + new TemplateError( + ErrorType.FATAL, + ErrorReason.DISABLED, + ErrorItem.FILTER, + e.getMessage(), + spec.getName(), + interpreter.getLineNumber(), + -1, + e + ) + ); + return null; + } + if (filter == null) { + continue; + } + + Object[] args = evaluateFilterArgs(spec, bindings, context); + Map kwargs = extractNamedParams(args); + Object[] positionalArgs = extractPositionalArgs(args); + + boolean wasSafeString = value instanceof SafeString; + if (wasSafeString) { + value = value.toString(); + } + + try { + value = filter.filter(value, interpreter, positionalArgs, kwargs); + } catch (ELException e) { + throw e; + } catch (RuntimeException e) { + throw new ELException( + String.format("Error in filter '%s': %s", spec.getName(), e.getMessage()), + e + ); + } + + if (wasSafeString && filter.preserveSafeString() && value instanceof String) { + value = new SafeString((String) value); + } + } + + return value; + } + + protected JinjavaInterpreter getInterpreter(ELContext context) { + return (JinjavaInterpreter) context + .getELResolver() + .getValue(context, null, ExtendedParser.INTERPRETER); + } + + protected Object[] evaluateFilterArgs( + FilterSpec spec, + Bindings bindings, + ELContext context + ) { + AstParameters params = spec.getParams(); + if (params == null || params.getCardinality() == 0) { + return new Object[0]; + } + + Object[] args = new Object[params.getCardinality()]; + for (int i = 0; i < params.getCardinality(); i++) { + args[i] = params.getChild(i).eval(bindings, context); + } + return args; + } + + private Map extractNamedParams(Object[] args) { + Map kwargs = new LinkedHashMap<>(); + for (Object arg : args) { + if (arg instanceof NamedParameter) { + NamedParameter namedParam = (NamedParameter) arg; + kwargs.put(namedParam.getName(), namedParam.getValue()); + } + } + return kwargs; + } + + private Object[] extractPositionalArgs(Object[] args) { + List positional = new ArrayList<>(); + for (Object arg : args) { + if (!(arg instanceof NamedParameter)) { + positional.add(arg); + } + } + return positional.toArray(); + } + + @Override + public void appendStructure(StringBuilder builder, Bindings bindings) { + input.appendStructure(builder, bindings); + for (FilterSpec spec : filterSpecs) { + builder.append('|').append(spec.getName()); + AstParameters params = spec.getParams(); + if (params != null && params.getCardinality() > 0) { + builder.append('('); + for (int i = 0; i < params.getCardinality(); i++) { + if (i > 0) { + builder.append(", "); + } + params.getChild(i).appendStructure(builder, bindings); + } + builder.append(')'); + } + } + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(input.toString()); + for (FilterSpec spec : filterSpecs) { + sb.append('|').append(spec.toString()); + } + return sb.toString(); + } + + @Override + public int getCardinality() { + return 1 + filterSpecs.size(); + } + + @Override + public AstNode getChild(int i) { + if (i == 0) { + return input; + } + int filterIndex = i - 1; + if (filterIndex < filterSpecs.size()) { + FilterSpec spec = filterSpecs.get(filterIndex); + return spec.getParams(); + } + return null; + } +} diff --git a/src/main/java/com/hubspot/jinjava/el/ext/ExtendedParser.java b/src/main/java/com/hubspot/jinjava/el/ext/ExtendedParser.java index 4f7f741a0..543726a7b 100644 --- a/src/main/java/com/hubspot/jinjava/el/ext/ExtendedParser.java +++ b/src/main/java/com/hubspot/jinjava/el/ext/ExtendedParser.java @@ -531,30 +531,11 @@ protected AstNode value() throws ScanException, ParseException { private AstNode parseOperators(AstNode left) throws ScanException, ParseException { if ("|".equals(getToken().getImage()) && lookahead(0).getSymbol() == IDENTIFIER) { - AstNode v = left; - - do { - consumeToken(); // '|' - String filterName = consumeToken().getImage(); - List filterParams = Lists.newArrayList(v, interpreter()); - - // optional filter args - if (getToken().getSymbol() == Symbol.LPAREN) { - AstParameters astParameters = params(); - for (int i = 0; i < astParameters.getCardinality(); i++) { - filterParams.add(astParameters.getChild(i)); - } - } - - AstProperty filterProperty = createAstDot( - identifier(FILTER_PREFIX + filterName), - "filter", - true - ); - v = createAstMethod(filterProperty, createAstParameters(filterParams)); // function("filter:" + filterName, new AstParameters(filterParams)); - } while ("|".equals(getToken().getImage())); - - return v; + if (shouldUseFilterChainOptimization()) { + return parseFiltersAsChain(left); + } else { + return parseFiltersAsNestedMethods(left); + } } else if ( "is".equals(getToken().getImage()) && "not".equals(lookahead(0).getImage()) && @@ -577,6 +558,68 @@ protected AstParameters createAstParameters(List nodes) { return new AstParameters(nodes); } + protected AstFilterChain createAstFilterChain( + AstNode input, + List filterSpecs + ) { + return new AstFilterChain(input, filterSpecs); + } + + private AstNode parseFiltersAsChain(AstNode left) throws ScanException, ParseException { + List filterSpecs = new ArrayList<>(); + + do { + consumeToken(); // '|' + String filterName = consumeToken().getImage(); + AstParameters filterParams = null; + + // optional filter args + if (getToken().getSymbol() == Symbol.LPAREN) { + filterParams = params(); + } + + filterSpecs.add(new FilterSpec(filterName, filterParams)); + } while ("|".equals(getToken().getImage())); + + return createAstFilterChain(left, filterSpecs); + } + + protected AstNode parseFiltersAsNestedMethods(AstNode left) + throws ScanException, ParseException { + AstNode v = left; + + do { + consumeToken(); // '|' + String filterName = consumeToken().getImage(); + List filterParams = Lists.newArrayList(v, interpreter()); + + // optional filter args + if (getToken().getSymbol() == Symbol.LPAREN) { + AstParameters astParameters = params(); + for (int i = 0; i < astParameters.getCardinality(); i++) { + filterParams.add(astParameters.getChild(i)); + } + } + + AstProperty filterProperty = createAstDot( + identifier(FILTER_PREFIX + filterName), + "filter", + true + ); + v = createAstMethod(filterProperty, createAstParameters(filterParams)); + } while ("|".equals(getToken().getImage())); + + return v; + } + + protected boolean shouldUseFilterChainOptimization() { + return JinjavaInterpreter + .getCurrentMaybe() + .map(JinjavaInterpreter::getConfig) + .map(JinjavaConfig::isEnableFilterChainOptimization) + .orElse(false); + } + private boolean isPossibleExpTest(Symbol symbol) { return VALID_SYMBOLS_FOR_EXP_TEST.contains(symbol); } diff --git a/src/main/java/com/hubspot/jinjava/el/ext/FilterSpec.java b/src/main/java/com/hubspot/jinjava/el/ext/FilterSpec.java new file mode 100644 index 000000000..175016913 --- /dev/null +++ b/src/main/java/com/hubspot/jinjava/el/ext/FilterSpec.java @@ -0,0 +1,48 @@ +package com.hubspot.jinjava.el.ext; + +import de.odysseus.el.tree.impl.ast.AstParameters; +import java.util.Objects; + +/** + * Specification for a filter in a filter chain. + * Holds the filter name and optional parameters. + */ +public class FilterSpec { + + private final String name; + private final AstParameters params; + + public FilterSpec(String name, AstParameters params) { + this.name = Objects.requireNonNull(name, "Filter name cannot be null"); + this.params = params; + } + + public String getName() { + return name; + } + + public AstParameters getParams() { + return params; + } + + public boolean hasParams() { + return params != null && params.getCardinality() > 0; + } + + @Override + public String toString() { + if (hasParams()) { + StringBuilder sb = new StringBuilder(name); + sb.append('('); + for (int i = 0; i < params.getCardinality(); i++) { + if (i > 0) { + sb.append(", "); + } + sb.append(params.getChild(i)); + } + sb.append(')'); + return sb.toString(); + } + return name; + } +} diff --git a/src/main/java/com/hubspot/jinjava/el/ext/eager/EagerExtendedParser.java b/src/main/java/com/hubspot/jinjava/el/ext/eager/EagerExtendedParser.java index 5ca383afd..29f53e2b7 100644 --- a/src/main/java/com/hubspot/jinjava/el/ext/eager/EagerExtendedParser.java +++ b/src/main/java/com/hubspot/jinjava/el/ext/eager/EagerExtendedParser.java @@ -198,4 +198,9 @@ protected AstList createAstList(AstParameters parameters) protected AstParameters createAstParameters(List nodes) { return new EagerAstParameters(nodes); } + + @Override + protected boolean shouldUseFilterChainOptimization() { + return false; + } } diff --git a/src/test/java/com/hubspot/jinjava/el/ext/AstFilterChainPerformanceTest.java b/src/test/java/com/hubspot/jinjava/el/ext/AstFilterChainPerformanceTest.java new file mode 100644 index 000000000..5ed90f2fa --- /dev/null +++ b/src/test/java/com/hubspot/jinjava/el/ext/AstFilterChainPerformanceTest.java @@ -0,0 +1,244 @@ +package com.hubspot.jinjava.el.ext; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.hubspot.jinjava.Jinjava; +import com.hubspot.jinjava.JinjavaConfig; +import java.util.HashMap; +import java.util.Map; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; + +/** + * Performance test to verify that the optimized filter chain performs better than + * the nested method approach. + * + * Run with: mvn test -Dtest=AstFilterChainPerformanceTest + * Or run the main() method directly for more detailed output. + */ +public class AstFilterChainPerformanceTest { + + private Jinjava jinjavaOptimized; + private Jinjava jinjavaUnoptimized; + private Map context; + + @Before + public void setup() { + jinjavaOptimized = + new Jinjava( + JinjavaConfig.newBuilder().withEnableFilterChainOptimization(true).build() + ); + + jinjavaUnoptimized = + new Jinjava( + JinjavaConfig.newBuilder().withEnableFilterChainOptimization(false).build() + ); + + context = new HashMap<>(); + context.put("name", " Hello World "); + context.put("text", "the quick brown fox jumps over the lazy dog"); + context.put("number", 12345); + context.put("items", new String[] { "apple", "banana", "cherry" }); + context.put("content", Map.of("text", "the quick brown fox jumps over the lazy dog")); + } + + public static void main(String[] args) { + AstFilterChainPerformanceTest test = new AstFilterChainPerformanceTest(); + test.setup(); + test.runPerformanceComparison(); + } + + /** + * Run this test manually to see detailed performance comparison. + * Use main() method or run with -Dtest=AstFilterChainPerformanceTest#runPerformanceComparison + */ + @Test + @Ignore("Manual performance test - run explicitly when needed") + public void runPerformanceComparison() { + int warmupIterations = 10000; + int testIterations = 100000; + + System.out.println("=== Filter Chain Performance Test ===\n"); + System.out.println("Warming up..."); + + // Warmup + runFilterTests(jinjavaOptimized, warmupIterations, false); + runFilterTests(jinjavaUnoptimized, warmupIterations, false); + + System.out.println( + "Running performance tests with " + testIterations + " iterations each\n" + ); + + // Single filter + comparePerformance("Single filter: {{ name|trim }}", testIterations); + + // Two chained filters + comparePerformance("Two filters: {{ name|trim|lower }}", testIterations); + + // Three chained filters + comparePerformance("Three filters: {{ name|trim|lower|capitalize }}", testIterations); + + // Five chained filters + comparePerformance( + "Five filters: {{ text|upper|replace('THE', 'a')|trim|lower|title }}", + testIterations + ); + + // Filter with arguments + comparePerformance( + "Filters with args: {{ text|truncate(20)|upper }}", + testIterations + ); + + // Multiple filter chains in same template + comparePerformance( + "Multiple chains: {{ name|trim|lower }} and {{ text|upper|truncate(10) }}", + testIterations + ); + } + + private void comparePerformance(String description, int iterations) { + String template = description.substring(description.indexOf("{{")); + if (description.contains(":")) { + template = description.substring(description.indexOf(":") + 2); + } + + System.out.println(description); + + // Run optimized + long optimizedTime = timeExecution(jinjavaOptimized, template, iterations); + + // Run unoptimized + long unoptimizedTime = timeExecution(jinjavaUnoptimized, template, iterations); + + double speedup = (double) unoptimizedTime / optimizedTime; + System.out.printf( + " Optimized: %d ms, Unoptimized: %d ms, Speedup: %.2fx%n%n", + optimizedTime, + unoptimizedTime, + speedup + ); + } + + private long timeExecution(Jinjava jinjava, String template, int iterations) { + long startTime = System.currentTimeMillis(); + for (int i = 0; i < iterations; i++) { + jinjava.render(template, context); + } + return System.currentTimeMillis() - startTime; + } + + private void runFilterTests(Jinjava jinjava, int iterations, boolean print) { + String[] templates = { + "{{ name|trim }}", + "{{ name|trim|lower }}", + "{{ name|trim|lower|capitalize }}", + "{{ text|upper|replace('THE', 'a')|trim|lower|title }}", + "{{ text|truncate(20)|upper }}", + }; + + for (String template : templates) { + for (int i = 0; i < iterations; i++) { + jinjava.render(template, context); + } + } + } + + @Test + public void itProducesSameResultsWithAndWithoutOptimization() { + String[] templates = { + "{{ name|trim }}", + "{{ name|trim|lower }}", + "{{ name|trim|lower|capitalize }}", + "{{ text|upper|replace('THE', 'a')|trim|lower|title }}", + "{{ text|truncate(20)|upper }}", + "{{ name|trim|lower }} and {{ text|upper|truncate(10) }}", + "{{ items|join(', ')|upper }}", + "{{ number|string|length }}", + }; + + for (String template : templates) { + String optimizedResult = jinjavaOptimized.render(template, context); + String unoptimizedResult = jinjavaUnoptimized.render(template, context); + assertThat(optimizedResult) + .as("Template: " + template) + .isEqualTo(unoptimizedResult); + } + } + + @Test + public void itHandlesSingleFilterWithOptimization() { + String result = jinjavaOptimized.render("{{ name|trim }}", context); + assertThat(result).isEqualTo("Hello World"); + } + + @Test + public void itHandlesChainedFiltersWithOptimization() { + String result = jinjavaOptimized.render("{{ name|trim|lower }}", context); + assertThat(result).isEqualTo("hello world"); + } + + @Test + public void itHandlesFiltersWithArgumentsWithOptimization() { + String result = jinjavaOptimized.render("{{ text|truncate(20)|upper }}", context); + assertThat(result).isNotEmpty(); + assertThat(result).isUpperCase(); + } + + @Test + public void itHandlesComplexFilterChainWithOptimization() { + String result = jinjavaOptimized.render( + "{{ text|upper|replace('THE', 'a')|trim|lower|capitalize }}", + context + ); + assertThat(result).isNotEmpty(); + } + + /** + * This test verifies that the optimized version is faster than the unoptimized version. + * The optimization should provide a measurable speedup for chained filters. + */ + @Test + public void optimizedVersionShouldBeFaster() { + int warmupIterations = 100; + int testIterations = 1000; + String template = "{{ content.text|upper|replace('THE', 'a')|trim|lower|title }}"; + + // Warmup both to ensure JIT compilation + for (int i = 0; i < warmupIterations; i++) { + jinjavaOptimized.render(template, context); + jinjavaUnoptimized.render(template, context); + } + + // Run multiple rounds to get more stable results + long totalOptimizedTime = 0; + long totalUnoptimizedTime = 0; + int rounds = 3; + + for (int round = 0; round < rounds; round++) { + totalUnoptimizedTime += timeExecution(jinjavaUnoptimized, template, testIterations); + totalOptimizedTime += timeExecution(jinjavaOptimized, template, testIterations); + } + + long avgUnoptimizedTime = totalUnoptimizedTime / rounds; + long avgOptimizedTime = totalOptimizedTime / rounds; + + System.out.printf( + "Performance test: Optimized=%d ms, Unoptimized=%d ms, Speedup=%.2fx%n", + avgOptimizedTime, + avgUnoptimizedTime, + (double) avgUnoptimizedTime / avgOptimizedTime + ); + + // The optimized version should be faster (allow 10% margin for system variance) + // If optimized takes more than 90% of unoptimized time, fail the test + assertThat(avgOptimizedTime) + .as( + "Optimized (%d ms) should be faster than unoptimized (%d ms)", + avgOptimizedTime, + avgUnoptimizedTime + ) + .isLessThan((long) (avgUnoptimizedTime * 0.95)); + } +}