|
33 | 33 | import org.apache.calcite.rex.LogicVisitor; |
34 | 34 | import org.apache.calcite.rex.RexBuilder; |
35 | 35 | import org.apache.calcite.rex.RexCorrelVariable; |
| 36 | +import org.apache.calcite.rex.RexFieldAccess; |
36 | 37 | import org.apache.calcite.rex.RexInputRef; |
37 | 38 | import org.apache.calcite.rex.RexLiteral; |
38 | 39 | import org.apache.calcite.rex.RexNode; |
|
47 | 48 | import org.apache.calcite.sql2rel.RelDecorrelator; |
48 | 49 | import org.apache.calcite.tools.RelBuilder; |
49 | 50 | import org.apache.calcite.util.ImmutableBitSet; |
| 51 | +import org.apache.calcite.util.Litmus; |
50 | 52 | import org.apache.calcite.util.Pair; |
51 | 53 |
|
52 | 54 | import com.google.common.collect.ImmutableList; |
| 55 | +import com.google.common.collect.ImmutableSet; |
53 | 56 | import com.google.common.collect.Iterables; |
54 | 57 |
|
55 | 58 | import org.immutables.value.Value; |
56 | 59 |
|
57 | 60 | import java.util.ArrayList; |
| 61 | +import java.util.HashMap; |
58 | 62 | import java.util.List; |
| 63 | +import java.util.Map; |
59 | 64 | import java.util.Set; |
60 | 65 | import java.util.stream.Collectors; |
61 | 66 |
|
@@ -967,10 +972,8 @@ private static void matchJoin(SubQueryRemoveRule rule, RelOptRuleCall call) { |
967 | 972 | boolean inputIntersectsRightSide = |
968 | 973 | inputSet.intersects(ImmutableBitSet.range(nFieldsLeft, nFieldsLeft + nFieldsRight)); |
969 | 974 | if (inputIntersectsLeftSide && inputIntersectsRightSide) { |
970 | | - // The current existential rewrite needs to make join with one side of the origin join and |
971 | | - // generate a new condition to replace the on clause. But for RexNode whose operands are |
972 | | - // on either side of the join, we can't push them into join. So this rewriting is not |
973 | | - // supported. |
| 975 | + rewriteSubQueryOnDomain(rule, call, e, join, nFieldsLeft, nFieldsRight, |
| 976 | + inputSet, builder, variablesSet); |
974 | 977 | return; |
975 | 978 | } |
976 | 979 |
|
@@ -1079,6 +1082,232 @@ private static void matchJoin(SubQueryRemoveRule rule, RelOptRuleCall call) { |
1079 | 1082 | call.transformTo(builder.build()); |
1080 | 1083 | } |
1081 | 1084 |
|
| 1085 | + /** |
| 1086 | + * Rewrites a sub-query that references columns from both the left and right inputs of a Join. |
| 1087 | + * |
| 1088 | + * <p>This method handles the complex case where a sub-query in a Join condition is correlated |
| 1089 | + * with both sides of the Join. It performs the following steps: |
| 1090 | + * <ol> |
| 1091 | + * <li>Identifies the "Domain" of values from the left and right inputs that are relevant |
| 1092 | + * to the sub-query.</li> |
| 1093 | + * <li>Constructs a "Computation Domain" by cross-joining the distinct keys from the left |
| 1094 | + * and right domains.</li> |
| 1095 | + * <li>Remaps the sub-query to operate on this Computation Domain.</li> |
| 1096 | + * <li>Rewrites the sub-query using the standard {@link #apply} method, but applied to the |
| 1097 | + * Domain.</li> |
| 1098 | + * <li>Re-integrates the result of the sub-query rewrite back into the original Join structure, |
| 1099 | + * ensuring correct join types and conditions are maintained.</li> |
| 1100 | + * </ol> |
| 1101 | + * |
| 1102 | + * @param rule The rule instance |
| 1103 | + * @param call The rule call |
| 1104 | + * @param e The sub-query to rewrite |
| 1105 | + * @param join The join containing the sub-query |
| 1106 | + * @param nFieldsLeft Number of fields in the left input |
| 1107 | + * @param nFieldsRight Number of fields in the right input |
| 1108 | + * @param inputSet BitSet of columns used by the sub-query |
| 1109 | + * @param builder The RelBuilder |
| 1110 | + * @param variablesSet Set of correlation variables used by the sub-query |
| 1111 | + */ |
| 1112 | + private static void rewriteSubQueryOnDomain(SubQueryRemoveRule rule, |
| 1113 | + RelOptRuleCall call, |
| 1114 | + RexSubQuery e, |
| 1115 | + Join join, |
| 1116 | + int nFieldsLeft, |
| 1117 | + int nFieldsRight, |
| 1118 | + ImmutableBitSet inputSet, |
| 1119 | + RelBuilder builder, |
| 1120 | + Set<CorrelationId> variablesSet) { |
| 1121 | + // Map to store the offset of each correlation variable |
| 1122 | + final Map<CorrelationId, Integer> idToOffset = new HashMap<>(); |
| 1123 | + // Helper to determine offset for each correlation variable |
| 1124 | + e.rel.accept(new CorrelationOffsetFinder(idToOffset, join, nFieldsLeft)); |
| 1125 | + |
| 1126 | + // 1. Identify which columns from Left and Right are used by the subquery. |
| 1127 | + // These will form the "Domain" on which the subquery is calculated. |
| 1128 | + final ImmutableBitSet leftUsed = |
| 1129 | + inputSet.intersect(ImmutableBitSet.range(0, nFieldsLeft)); |
| 1130 | + final ImmutableBitSet rightUsed = |
| 1131 | + inputSet.intersect(ImmutableBitSet.range(nFieldsLeft, nFieldsLeft + nFieldsRight)); |
| 1132 | + |
| 1133 | + // 2. Build the "Computation Domain". |
| 1134 | + // This is a Cross Join of the distinct keys from Left and Right. |
| 1135 | + // Domain = Distinct(Project(LeftUsed)) x Distinct(Project(RightUsed)) |
| 1136 | + |
| 1137 | + // 2a. Left Domain |
| 1138 | + builder.push(join.getLeft()); |
| 1139 | + builder.project(builder.fields(leftUsed)); |
| 1140 | + builder.distinct(); |
| 1141 | + |
| 1142 | + // 2b. Right Domain |
| 1143 | + builder.push(join.getRight()); |
| 1144 | + // We must shift the bitset to be 0-based for the Right input |
| 1145 | + ImmutableBitSet rightUsedShifted = rightUsed.shift(-nFieldsLeft); |
| 1146 | + builder.project(builder.fields(rightUsedShifted)); |
| 1147 | + builder.distinct(); |
| 1148 | + |
| 1149 | + // 2c. Create Domain Cross Join |
| 1150 | + builder.join(JoinRelType.INNER, builder.literal(true)); |
| 1151 | + |
| 1152 | + // 3. Remap the SubQuery to run on the Domain. |
| 1153 | + // We need to map original field indices to their new positions in the Domain. |
| 1154 | + // Original: [LeftFields... | RightFields...] |
| 1155 | + // Domain: [LeftUsed... | RightUsed...] |
| 1156 | + final Map<Integer, Integer> mapping = new HashMap<>(); |
| 1157 | + int targetIdx = 0; |
| 1158 | + for (int source : leftUsed) { |
| 1159 | + mapping.put(source, targetIdx++); |
| 1160 | + } |
| 1161 | + for (int source : rightUsed) { |
| 1162 | + mapping.put(source, targetIdx++); |
| 1163 | + } |
| 1164 | + |
| 1165 | + final RexBuilder rexBuilder = builder.getRexBuilder(); |
| 1166 | + final CorrelationId domainCorrId = join.getCluster().createCorrel(); |
| 1167 | + final RexNode domainCorrVar = rexBuilder.makeCorrel(builder.peek().getRowType(), domainCorrId); |
| 1168 | + |
| 1169 | + // Shuttle to replace InputRefs and Correlations with references to the Domain |
| 1170 | + RexShuttle shuttle = new InputRefAndCorrelationReplacer(mapping, variablesSet, idToOffset); |
| 1171 | + // Create the new subquery with operands remapped to the Domain |
| 1172 | + RexNode newSubQueryNode = e.accept(shuttle); |
| 1173 | + |
| 1174 | + // Rewrite e.rel to use domainCorrId |
| 1175 | + RelNode newRel = e.rel.accept(new DomainRewriter(variablesSet, idToOffset, mapping, |
| 1176 | + rexBuilder, domainCorrVar)); |
| 1177 | + |
| 1178 | + if (newSubQueryNode instanceof RexSubQuery) { |
| 1179 | + newSubQueryNode = ((RexSubQuery) newSubQueryNode).clone(newRel); |
| 1180 | + } |
| 1181 | + |
| 1182 | + // We introduced a new correlation variable domainCorrId. |
| 1183 | + Set<CorrelationId> newVariablesSet = ImmutableSet.of(domainCorrId); |
| 1184 | + |
| 1185 | + final RelOptUtil.Logic logic = |
| 1186 | + LogicVisitor.find(join.getJoinType().generatesNullsOnRight() |
| 1187 | + ? RelOptUtil.Logic.TRUE_FALSE_UNKNOWN : RelOptUtil.Logic.TRUE, |
| 1188 | + ImmutableList.of(join.getCondition()), e); |
| 1189 | + |
| 1190 | + // 4. Apply the standard rewriting rule to the Domain. |
| 1191 | + // The builder is currently sitting on the Domain Join. |
| 1192 | + // 'target' is the CASE expression (or similar) resulting from the rewrite. |
| 1193 | + // The builder stack now has the result of the rewrite (e.g. Domain Left Join Aggregate). |
| 1194 | + assert newSubQueryNode instanceof RexSubQuery; |
| 1195 | + final RexNode target = |
| 1196 | + rule.apply((RexSubQuery) newSubQueryNode, newVariablesSet, logic, builder, |
| 1197 | + 1, builder.peek().getRowType().getFieldCount(), 0); |
| 1198 | + |
| 1199 | + // The target references the Domain Result (which is currently at the top of the builder). |
| 1200 | + // In the final plan, the Domain Result will be joined to the right of the original inputs. |
| 1201 | + // Furthermore, since we use a LEFT JOIN, the Domain Result columns become nullable. |
| 1202 | + // So we need to shift the references in target AND make them nullable. |
| 1203 | + final int offset = nFieldsLeft + nFieldsRight; |
| 1204 | + final RexShuttle shiftAndNullableShuttle = new RexShuttle() { |
| 1205 | + @Override public RexNode visitInputRef(RexInputRef inputRef) { |
| 1206 | + // Shift the index |
| 1207 | + int newIndex = inputRef.getIndex() + offset; |
| 1208 | + return new RexInputRef(newIndex, inputRef.getType()); |
| 1209 | + } |
| 1210 | + }; |
| 1211 | + final RexNode shiftedTarget = target.accept(shiftAndNullableShuttle); |
| 1212 | + |
| 1213 | + // 5. Re-integrate with Original Inputs |
| 1214 | + // Stack has: [RewriteResult] |
| 1215 | + RelNode domainResult = builder.build(); |
| 1216 | + |
| 1217 | + // Rebuild the original Join structure |
| 1218 | + // We want to construct: Left JOIN (Right JOIN Domain) ON ... |
| 1219 | + // This preserves the JoinRelType of the original join. |
| 1220 | + JoinRelType joinType = join.getJoinType(); |
| 1221 | + if (joinType == JoinRelType.RIGHT) { |
| 1222 | + // Symmetric to LEFT/INNER/FULL but attached to Left |
| 1223 | + builder.push(join.getLeft()); |
| 1224 | + builder.push(domainResult); |
| 1225 | + |
| 1226 | + // Join Left and Domain on Left Keys |
| 1227 | + List<RexNode> leftJoinConditions = new ArrayList<>(); |
| 1228 | + int domainIdx = 0; // Left Keys are at start of Domain |
| 1229 | + for (int source : leftUsed) { |
| 1230 | + leftJoinConditions.add( |
| 1231 | + builder.equals( |
| 1232 | + builder.field(2, 0, source), |
| 1233 | + builder.field(2, 1, domainIdx++))); |
| 1234 | + } |
| 1235 | + builder.join(JoinRelType.INNER, builder.and(leftJoinConditions)); |
| 1236 | + |
| 1237 | + // Now Join Right |
| 1238 | + builder.push(join.getRight()); |
| 1239 | + // Stack: (Left+Domain), Right |
| 1240 | + |
| 1241 | + // Join Condition: Original + Right Keys match |
| 1242 | + List<RexNode> rightJoinConditions = new ArrayList<>(); |
| 1243 | + // Domain starts after Left. Right Keys in Domain are after Left Keys. |
| 1244 | + int domainRightKeyIdx = nFieldsLeft + leftUsed.cardinality(); |
| 1245 | + for (int source : rightUsed) { |
| 1246 | + // Right input (index 1) |
| 1247 | + RexInputRef field = builder.field(2, 1, source - nFieldsLeft); |
| 1248 | + // (Left+Domain) input (index 0) |
| 1249 | + RexInputRef field1 = builder.field(2, 0, domainRightKeyIdx++); |
| 1250 | + rightJoinConditions.add(builder.equals(field, field1)); |
| 1251 | + } |
| 1252 | + |
| 1253 | + RexShuttle replaceShuttle = new ReplaceSubQueryShuttle(e, shiftedTarget); |
| 1254 | + RexNode newJoinCondition = join.getCondition().accept(replaceShuttle); |
| 1255 | + |
| 1256 | + builder.join(joinType, builder.and(builder.and(rightJoinConditions), newJoinCondition)); |
| 1257 | + |
| 1258 | + builder.project(fields(builder, nFieldsLeft + nFieldsRight)); |
| 1259 | + } else { |
| 1260 | + // For INNER, LEFT, FULL join, we can attach Domain to Right, then Join Left. |
| 1261 | + // 1. Build (Right JOIN Domain) |
| 1262 | + builder.push(join.getRight()); |
| 1263 | + builder.push(domainResult); |
| 1264 | + |
| 1265 | + // Join Right and Domain on Right Keys |
| 1266 | + // Domain layout: [LeftKeys, RightKeys] |
| 1267 | + List<RexNode> rightJoinConditions = new ArrayList<>(); |
| 1268 | + // Skip Left Keys |
| 1269 | + int domainIdx = leftUsed.cardinality(); |
| 1270 | + for (int source : rightUsed) { |
| 1271 | + rightJoinConditions.add( |
| 1272 | + builder.equals( |
| 1273 | + builder.field(2, 0, source - nFieldsLeft), // Right input |
| 1274 | + builder.field(2, 1, domainIdx++))); // Domain input |
| 1275 | + } |
| 1276 | + // We use INNER join here to expand Right with Domain values. |
| 1277 | + // Since Domain contains all distinct Right keys, this is safe. |
| 1278 | + builder.join(JoinRelType.INNER, builder.and(rightJoinConditions)); |
| 1279 | + |
| 1280 | + // 2. Join Left with (Right JOIN Domain) |
| 1281 | + RelNode rightWithDomain = builder.build(); |
| 1282 | + builder.push(join.getLeft()); |
| 1283 | + builder.push(rightWithDomain); |
| 1284 | + |
| 1285 | + // Join Condition: Original Condition (rewritten) AND Left.LeftKeys = Domain.LeftKeys |
| 1286 | + List<RexNode> leftJoinConditions = new ArrayList<>(); |
| 1287 | + // In (Right+Domain), Domain fields start after Right fields |
| 1288 | + int domainStartInCombined = nFieldsRight; |
| 1289 | + int domainLeftKeyIdx = domainStartInCombined; // Left Keys are at start of Domain |
| 1290 | + |
| 1291 | + for (int source : leftUsed) { |
| 1292 | + // Left input |
| 1293 | + RexInputRef field = builder.field(2, 0, source); |
| 1294 | + // (Right+Domain) input |
| 1295 | + RexInputRef field1 = builder.field(2, 1, domainLeftKeyIdx++); |
| 1296 | + leftJoinConditions.add(builder.equals(field, field1)); |
| 1297 | + } |
| 1298 | + |
| 1299 | + RexShuttle replaceShuttle = new ReplaceSubQueryShuttle(e, shiftedTarget); |
| 1300 | + RexNode newJoinCondition = join.getCondition().accept(replaceShuttle); |
| 1301 | + |
| 1302 | + builder.join(joinType, builder.and(builder.and(leftJoinConditions), newJoinCondition)); |
| 1303 | + |
| 1304 | + // Project original fields (remove Domain columns) |
| 1305 | + builder.project(fields(builder, nFieldsLeft + nFieldsRight)); |
| 1306 | + } |
| 1307 | + |
| 1308 | + call.transformTo(builder.build()); |
| 1309 | + } |
| 1310 | + |
1082 | 1311 | private static void matchFilterEnableMarkJoin(SubQueryRemoveRule rule, RelOptRuleCall call) { |
1083 | 1312 | final Filter filter = call.rel(0); |
1084 | 1313 | final Set<CorrelationId> variablesSet = filter.getVariablesSet(); |
@@ -1212,6 +1441,125 @@ private static class ReplaceSubQueryShuttle extends RexShuttle { |
1212 | 1441 | return subQuery.equals(this.subQuery) ? replacement : subQuery; |
1213 | 1442 | } |
1214 | 1443 | } |
| 1444 | + |
| 1445 | + /** |
| 1446 | + * Shuttle that finds correlation variables and determines their offset. |
| 1447 | + */ |
| 1448 | + private static class CorrelationOffsetFinder extends RelHomogeneousShuttle { |
| 1449 | + private final Map<CorrelationId, Integer> idToOffset; |
| 1450 | + private final Join join; |
| 1451 | + private final int nFieldsLeft; |
| 1452 | + |
| 1453 | + CorrelationOffsetFinder(Map<CorrelationId, Integer> idToOffset, Join join, int nFieldsLeft) { |
| 1454 | + this.idToOffset = idToOffset; |
| 1455 | + this.join = join; |
| 1456 | + this.nFieldsLeft = nFieldsLeft; |
| 1457 | + } |
| 1458 | + |
| 1459 | + @Override public RelNode visit(RelNode other) { |
| 1460 | + other.accept(new RexShuttle() { |
| 1461 | + @Override public RexNode visitCorrelVariable(RexCorrelVariable correlVariable) { |
| 1462 | + if (!idToOffset.containsKey(correlVariable.id)) { |
| 1463 | + // Check if type matches Left |
| 1464 | + if (RelOptUtil.eq("type1", correlVariable.getType(), |
| 1465 | + "type2", join.getLeft().getRowType(), Litmus.IGNORE)) { |
| 1466 | + idToOffset.put(correlVariable.id, 0); |
| 1467 | + } else if (RelOptUtil.eq("type1", correlVariable.getType(), |
| 1468 | + "type2", join.getRight().getRowType(), Litmus.IGNORE)) { |
| 1469 | + idToOffset.put(correlVariable.id, nFieldsLeft); |
| 1470 | + } else { |
| 1471 | + // Default to 0 if unknown |
| 1472 | + idToOffset.put(correlVariable.id, 0); |
| 1473 | + } |
| 1474 | + } |
| 1475 | + return super.visitCorrelVariable(correlVariable); |
| 1476 | + } |
| 1477 | + }); |
| 1478 | + return super.visit(other); |
| 1479 | + } |
| 1480 | + } |
| 1481 | + |
| 1482 | + /** |
| 1483 | + * Shuttle that replaces InputRefs and Correlations with references to the Domain. |
| 1484 | + */ |
| 1485 | + private static class InputRefAndCorrelationReplacer extends RexShuttle { |
| 1486 | + private final Map<Integer, Integer> mapping; |
| 1487 | + private final Set<CorrelationId> variablesSet; |
| 1488 | + private final Map<CorrelationId, Integer> idToOffset; |
| 1489 | + |
| 1490 | + InputRefAndCorrelationReplacer(Map<Integer, Integer> mapping, |
| 1491 | + Set<CorrelationId> variablesSet, Map<CorrelationId, Integer> idToOffset) { |
| 1492 | + this.mapping = mapping; |
| 1493 | + this.variablesSet = variablesSet; |
| 1494 | + this.idToOffset = idToOffset; |
| 1495 | + } |
| 1496 | + |
| 1497 | + @Override public RexNode visitInputRef(RexInputRef inputRef) { |
| 1498 | + Integer newIndex = mapping.get(inputRef.getIndex()); |
| 1499 | + if (newIndex != null) { |
| 1500 | + return new RexInputRef(newIndex, inputRef.getType()); |
| 1501 | + } |
| 1502 | + return super.visitInputRef(inputRef); |
| 1503 | + } |
| 1504 | + |
| 1505 | + @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { |
| 1506 | + RexNode refExpr = fieldAccess.getReferenceExpr(); |
| 1507 | + if (refExpr instanceof RexCorrelVariable) { |
| 1508 | + CorrelationId id = ((RexCorrelVariable) refExpr).id; |
| 1509 | + if (variablesSet.contains(id)) { |
| 1510 | + int fieldIdx = fieldAccess.getField().getIndex(); |
| 1511 | + int offset = idToOffset.getOrDefault(id, 0); |
| 1512 | + Integer newIndex = mapping.get(fieldIdx + offset); |
| 1513 | + if (newIndex != null) { |
| 1514 | + return new RexInputRef(newIndex, fieldAccess.getType()); |
| 1515 | + } |
| 1516 | + } |
| 1517 | + } |
| 1518 | + return super.visitFieldAccess(fieldAccess); |
| 1519 | + } |
| 1520 | + } |
| 1521 | + |
| 1522 | + /** |
| 1523 | + * Shuttle that rewrites RelNodes to use the Domain correlation variable. |
| 1524 | + */ |
| 1525 | + private static class DomainRewriter extends RelHomogeneousShuttle { |
| 1526 | + private final Set<CorrelationId> variablesSet; |
| 1527 | + private final Map<CorrelationId, Integer> idToOffset; |
| 1528 | + private final Map<Integer, Integer> mapping; |
| 1529 | + private final RexBuilder rexBuilder; |
| 1530 | + private final RexNode domainCorrVar; |
| 1531 | + |
| 1532 | + DomainRewriter(Set<CorrelationId> variablesSet, Map<CorrelationId, Integer> idToOffset, |
| 1533 | + Map<Integer, Integer> mapping, RexBuilder rexBuilder, RexNode domainCorrVar) { |
| 1534 | + this.variablesSet = variablesSet; |
| 1535 | + this.idToOffset = idToOffset; |
| 1536 | + this.mapping = mapping; |
| 1537 | + this.rexBuilder = rexBuilder; |
| 1538 | + this.domainCorrVar = domainCorrVar; |
| 1539 | + } |
| 1540 | + |
| 1541 | + @Override public RelNode visit(RelNode other) { |
| 1542 | + return super.visit( |
| 1543 | + other.accept(new RexShuttle() { |
| 1544 | + @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { |
| 1545 | + RexNode refExpr = fieldAccess.getReferenceExpr(); |
| 1546 | + if (refExpr instanceof RexCorrelVariable) { |
| 1547 | + CorrelationId id = ((RexCorrelVariable) refExpr).id; |
| 1548 | + if (variablesSet.contains(id)) { |
| 1549 | + int fieldIdx = fieldAccess.getField().getIndex(); |
| 1550 | + int offset = idToOffset.getOrDefault(id, 0); |
| 1551 | + Integer newIndex = mapping.get(fieldIdx + offset); |
| 1552 | + if (newIndex != null) { |
| 1553 | + return rexBuilder.makeFieldAccess(domainCorrVar, newIndex); |
| 1554 | + } |
| 1555 | + } |
| 1556 | + } |
| 1557 | + return super.visitFieldAccess(fieldAccess); |
| 1558 | + } |
| 1559 | + })); |
| 1560 | + } |
| 1561 | + } |
| 1562 | + |
1215 | 1563 | /** Rule configuration. */ |
1216 | 1564 | @Value.Immutable(singleton = false) |
1217 | 1565 | public interface Config extends RelRule.Config { |
|
0 commit comments