/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.AggregationDecorrelation;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AssignUniqueId;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanNodeDecorrelator;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture;
import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
import org.apache.tsfile.read.common.type.LongType;
import org.apache.tsfile.read.common.type.Type;
import org.apache.tsfile.read.common.type.Type;

/**
 * Decorrelates a LEFT {@link CorrelatedJoinNode} whose subquery is a projection directly over a
 * distinct aggregation, i.e. an aggregation accepted by {@link
 * AggregationDecorrelation#isDistinctOperator}.
 *
 * <p>The rule matches only when the correlated join is of type LEFT, has a non-empty correlation
 * list and its filter is the literal {@code TRUE}:
 *
 * <pre>
 * - CorrelatedJoin LEFT (correlation: non-empty, filter: TRUE)
 *   - input
 *   - Project
 *     - Aggregation (distinct operator, grouping keys K)
 *       - aggregation source (containing correlated filters)
 * </pre>
 *
 * <p>and rewrites it to:
 *
 * <pre>
 * - Project (identities of the aggregation outputs exposed by the correlated join
 *            + the original projection's assignments)
 *   - Aggregation (same id/aggregations/step, grouping keys = left join outputs + K)
 *     - Join LEFT (no equi-criteria, filter: decorrelated correlated predicates)
 *       - AssignUniqueId (new symbol "unique")
 *         - input
 *       - decorrelated aggregation source
 * </pre>
 *
 * <p>If the aggregation source cannot be decorrelated by {@link PlanNodeDecorrelator}, the rule
 * does not fire.
 */
public class TransformCorrelatedDistinctAggregationWithProjection
        implements Rule<CorrelatedJoinNode> {
    private static final Capture<ProjectNode> PROJECTION = Capture.newCapture();
    private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();

    private static final Pattern<CorrelatedJoinNode> PATTERN =
            Patterns.correlatedJoin()
                    .with(Patterns.CorrelatedJoin.type().equalTo(JoinNode.JoinType.LEFT))
                    .with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation()))
                    .with(Patterns.CorrelatedJoin.filter().equalTo(BooleanLiteral.TRUE_LITERAL))
                    .with(
                            Patterns.CorrelatedJoin.subquery()
                                    .matching(
                                            Patterns.project()
                                                    .capturedAs(PROJECTION)
                                                    .with(
                                                            Patterns.source()
                                                                    .matching(
                                                                            Patterns.aggregation()
                                                                                    .matching(
                                                                                            AggregationDecorrelation
                                                                                                    ::isDistinctOperator)
                                                                                    .capturedAs(AGGREGATION)))));

    private final PlannerContext plannerContext;

    /**
     * Creates the rule.
     *
     * @param plannerContext planner context handed to {@link PlanNodeDecorrelator}; must not be
     *     null
     * @throws NullPointerException if {@code plannerContext} is null
     */
    public TransformCorrelatedDistinctAggregationWithProjection(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the matched correlated join into a LEFT join followed by the restored distinct
     * aggregation and the original projection.
     *
     * @param correlatedJoinNode the matched correlated join
     * @param captures captures holding the subquery's projection and distinct aggregation
     * @param context rule context supplying the id allocator, symbol allocator and lookup
     * @return the rewritten plan, or {@link Rule.Result#empty()} when the aggregation source
     *     cannot be decorrelated
     */
    @Override
    public Rule.Result apply(
            CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        AggregationNode distinctAggregation = captures.get(AGGREGATION);
        ProjectNode projection = captures.get(PROJECTION);

        // Try to pull the correlated filters out of the aggregation's source.
        PlanNodeDecorrelator decorrelator =
                new PlanNodeDecorrelator(
                        this.plannerContext, context.getSymbolAllocator(), context.getLookup());
        Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource =
                decorrelator.decorrelateFilters(
                        distinctAggregation.getChild(), correlatedJoinNode.getCorrelation());
        if (!decorrelatedSource.isPresent()) {
            return Rule.Result.empty();
        }
        PlanNode source = decorrelatedSource.get().getNode();

        // Tag each input row with a fresh "unique" symbol. Because all left-side symbols become
        // grouping keys below, this keeps otherwise-identical input rows from collapsing into a
        // single group (assuming AssignUniqueId yields a distinct value per row).
        PlanNode inputWithUniqueId =
                new AssignUniqueId(
                        context.getIdAllocator().genPlanNodeId(),
                        correlatedJoinNode.getInput(),
                        context.getSymbolAllocator().newSymbol("unique", LongType.getInstance()));

        // The former correlated predicates become the LEFT join's filter; there are no
        // equi-join criteria.
        JoinNode join =
                new JoinNode(
                        context.getIdAllocator().genPlanNodeId(),
                        JoinNode.JoinType.LEFT,
                        inputWithUniqueId,
                        source,
                        ImmutableList.<JoinNode.EquiJoinClause>of(),
                        Optional.empty(),
                        inputWithUniqueId.getOutputSymbols(),
                        source.getOutputSymbols(),
                        decorrelatedSource.get().getCorrelatedPredicates(),
                        Optional.empty());

        // Re-apply the distinct aggregation on top of the join, grouped per input row
        // (all left outputs) plus the original grouping keys.
        List<Symbol> groupingKeys =
                ImmutableList.<Symbol>builder()
                        .addAll(join.getLeftOutputSymbols())
                        .addAll(distinctAggregation.getGroupingKeys())
                        .build();
        AggregationNode restoredAggregation =
                new AggregationNode(
                        distinctAggregation.getPlanNodeId(),
                        join,
                        distinctAggregation.getAggregations(),
                        AggregationNode.singleGroupingSet(groupingKeys),
                        ImmutableList.<Symbol>of(),
                        distinctAggregation.getStep(),
                        Optional.empty(),
                        Optional.empty());

        // Keep only the aggregation outputs that the correlated join itself exposed, then add
        // the subquery's projection on top.
        Set<Symbol> correlatedJoinOutputs = new HashSet<>(correlatedJoinNode.getOutputSymbols());
        List<Symbol> expectedAggregationOutputs =
                restoredAggregation.getOutputSymbols().stream()
                        .filter(correlatedJoinOutputs::contains)
                        .collect(ImmutableList.toImmutableList());

        Assignments assignments =
                Assignments.builder()
                        .putIdentities(expectedAggregationOutputs)
                        .putAll(projection.getAssignments())
                        .build();

        return Rule.Result.ofPlanNode(
                new ProjectNode(
                        context.getIdAllocator().genPlanNodeId(), restoredAggregation, assignments));
    }
}

