/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.abilities.source;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeName;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.connector.source.DynamicTableSource;
import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown;
import org.apache.flink.table.expressions.AggregateExpression;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.planner.functions.aggfunctions.AvgAggFunction;
import org.apache.flink.table.planner.functions.aggfunctions.CountAggFunction;
import org.apache.flink.table.planner.functions.aggfunctions.Sum0AggFunction;
import org.apache.flink.table.planner.plan.abilities.source.SourceAbilityContext;
import org.apache.flink.table.planner.plan.abilities.source.SourceAbilitySpecBase;
import org.apache.flink.table.planner.plan.utils.AggregateInfo;
import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
import org.apache.flink.table.planner.plan.utils.AggregateUtil;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.utils.TypeConversions;
import org.apache.flink.util.Preconditions;
import scala.Tuple2;

@JsonTypeName(value="AggregatePushDown")
public final class AggregatePushDownSpec
extends SourceAbilitySpecBase {
    public static final String FIELD_NAME_INPUT_TYPE = "inputType";
    public static final String FIELD_NAME_GROUPING_SETS = "groupingSets";
    public static final String FIELD_NAME_AGGREGATE_CALLS = "aggregateCalls";
    @JsonProperty(value="inputType")
    private final RowType inputType;
    @JsonProperty(value="groupingSets")
    private final List<int[]> groupingSets;
    @JsonProperty(value="aggregateCalls")
    private final List<AggregateCall> aggregateCalls;

    @JsonCreator
    public AggregatePushDownSpec(@JsonProperty(value="inputType") RowType inputType, @JsonProperty(value="groupingSets") List<int[]> groupingSets, @JsonProperty(value="aggregateCalls") List<AggregateCall> aggregateCalls, @JsonProperty(value="producedType") RowType producedType) {
        super(producedType);
        this.inputType = inputType;
        this.groupingSets = new ArrayList<int[]>((Collection)Preconditions.checkNotNull(groupingSets));
        this.aggregateCalls = aggregateCalls;
    }

    @Override
    public void apply(DynamicTableSource tableSource, SourceAbilityContext context) {
        Preconditions.checkArgument((boolean)this.getProducedType().isPresent());
        AggregatePushDownSpec.apply(this.inputType, this.groupingSets, this.aggregateCalls, this.getProducedType().get(), tableSource, context);
    }

    @Override
    public String getDigests(SourceAbilityContext context) {
        int[] grouping = this.groupingSets.get(0);
        String groupingStr = Arrays.stream(grouping).mapToObj(index -> (String)this.inputType.getFieldNames().get(index)).collect(Collectors.joining(","));
        List<AggregateExpression> aggregateExpressions = AggregatePushDownSpec.buildAggregateExpressions(this.inputType, this.aggregateCalls);
        String aggFunctionsStr = aggregateExpressions.stream().map(AggregateExpression::asSummaryString).collect(Collectors.joining(","));
        return "aggregates=[grouping=[" + groupingStr + "], aggFunctions=[" + aggFunctionsStr + "]]";
    }

    public static boolean apply(RowType inputType, List<int[]> groupingSets, List<AggregateCall> aggregateCalls, RowType producedType, DynamicTableSource tableSource, SourceAbilityContext context) {
        assert (context.isBatchMode() && groupingSets.size() == 1);
        List<AggregateExpression> aggregateExpressions = AggregatePushDownSpec.buildAggregateExpressions(inputType, aggregateCalls);
        if (tableSource instanceof SupportsAggregatePushDown) {
            DataType producedDataType = TypeConversions.fromLogicalToDataType((LogicalType)producedType);
            return ((SupportsAggregatePushDown)tableSource).applyAggregates(groupingSets, aggregateExpressions, producedDataType);
        }
        throw new TableException(String.format("%s does not support SupportsAggregatePushDown.", tableSource.getClass().getName()));
    }

    private static List<AggregateExpression> buildAggregateExpressions(RowType inputType, List<AggregateCall> aggregateCalls) {
        AggregateInfoList aggInfoList = AggregateUtil.transformToBatchAggregateInfoList(inputType, JavaScalaConversionUtil.toScala(aggregateCalls), null, null);
        if (aggInfoList.aggInfos().length == 0) {
            return Collections.emptyList();
        }
        ArrayList<AggregateExpression> aggExpressions = new ArrayList<AggregateExpression>();
        for (AggregateInfo aggInfo : aggInfoList.aggInfos()) {
            ArrayList<FieldReferenceExpression> arguments = new ArrayList<FieldReferenceExpression>(1);
            for (int argIndex : aggInfo.argIndexes()) {
                DataType argType = TypeConversions.fromLogicalToDataType((LogicalType)((RowType.RowField)inputType.getFields().get(argIndex)).getType());
                FieldReferenceExpression field = new FieldReferenceExpression((String)inputType.getFieldNames().get(argIndex), argType, 0, argIndex);
                arguments.add(field);
            }
            if (aggInfo.function() instanceof AvgAggFunction) {
                Tuple2<Sum0AggFunction, CountAggFunction> sum0AndCountFunction = AggregateUtil.deriveSumAndCountFromAvg((AvgAggFunction)aggInfo.function());
                AggregateExpression sum0Expression = new AggregateExpression((FunctionDefinition)sum0AndCountFunction._1(), arguments, null, aggInfo.externalResultType(), aggInfo.agg().isDistinct(), aggInfo.agg().isApproximate(), aggInfo.agg().ignoreNulls());
                aggExpressions.add(sum0Expression);
                AggregateExpression countExpression = new AggregateExpression((FunctionDefinition)sum0AndCountFunction._2(), arguments, null, aggInfo.externalResultType(), aggInfo.agg().isDistinct(), aggInfo.agg().isApproximate(), aggInfo.agg().ignoreNulls());
                aggExpressions.add(countExpression);
                continue;
            }
            AggregateExpression aggregateExpression = new AggregateExpression((FunctionDefinition)aggInfo.function(), arguments, null, aggInfo.externalResultType(), aggInfo.agg().isDistinct(), aggInfo.agg().isApproximate(), aggInfo.agg().ignoreNulls());
            aggExpressions.add(aggregateExpression);
        }
        return aggExpressions;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        AggregatePushDownSpec that = (AggregatePushDownSpec)o;
        return Objects.equals(this.inputType, that.inputType) && Objects.equals(this.groupingSets, that.groupingSets) && Objects.equals(this.aggregateCalls, that.aggregateCalls);
    }

    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), this.inputType, this.groupingSets, this.aggregateCalls);
    }
}

