package tech.ydb.spark.connector.write;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.DataWriterFactory;
import org.apache.spark.sql.connector.write.LogicalWriteInfo;
import org.apache.spark.sql.connector.write.PhysicalWriteInfo;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.collection.Iterator;
import tech.ydb.core.Status;
import tech.ydb.proto.topic.YdbTopic;
import tech.ydb.spark.connector.YdbTable;
import tech.ydb.spark.connector.YdbTypes;
import tech.ydb.spark.connector.common.FieldInfo;
import tech.ydb.spark.connector.common.FieldType;
import tech.ydb.spark.connector.common.IngestMethod;
import tech.ydb.spark.connector.common.OperationOption;
import tech.ydb.table.query.Params;
import tech.ydb.table.values.ListValue;

/* loaded from: input_file:tech/ydb/spark/connector/write/YdbWriterFactory.class */
public class YdbWriterFactory implements DataWriterFactory {
    private static final long serialVersionUID = -6000846276376311177L;
    private static final Logger logger = LoggerFactory.getLogger((Class<?>) YdbWriterFactory.class);
    private final YdbTable table;
    private final YdbTypes types;
    private final StructType schema;
    private final IngestMethod method;
    private final String autoPkName;
    private final int maxBatchSize;

    public YdbWriterFactory(YdbTable ydbTable, LogicalWriteInfo logicalWriteInfo, PhysicalWriteInfo physicalWriteInfo) {
        this.table = ydbTable;
        this.types = new YdbTypes(logicalWriteInfo.options());
        this.method = (IngestMethod) OperationOption.INGEST_METHOD.readEnum(logicalWriteInfo.options(), IngestMethod.BULK_UPSERT);
        this.maxBatchSize = OperationOption.BATCH_SIZE.readInt(logicalWriteInfo.options(), YdbTopic.Codec.CODEC_CUSTOM_VALUE);
        this.autoPkName = OperationOption.AUTO_PK.read(logicalWriteInfo.options(), OperationOption.DEFAULT_AUTO_PK);
        this.schema = logicalWriteInfo.schema();
    }

    public DataWriter<InternalRow> createWriter(int i, long j) {
        logger.debug("New writer for table {}, partition {}, task {}", this.table.getTablePath(), Integer.valueOf(i), Long.valueOf(j));
        HashMap hashMap = new HashMap();
        for (FieldInfo fieldInfo : this.table.getAllColumns()) {
            hashMap.put(fieldInfo.getName(), fieldInfo);
        }
        HashMap hashMap2 = new HashMap();
        HashMap hashMap3 = new HashMap();
        Iterator iterator = this.schema.toIterator();
        int i2 = 0;
        while (iterator.hasNext()) {
            StructField structField = (StructField) iterator.next();
            String name = structField.name();
            if (!hashMap.containsKey(name)) {
                throw new IllegalArgumentException("Cannot write column " + name + " to table " + this.table);
            }
            FieldInfo fieldInfo2 = (FieldInfo) hashMap.get(name);
            hashMap2.put(name, fieldInfo2.toYdbType());
            int i3 = i2;
            i2++;
            hashMap3.put(name, new ColumnReader(i3, structField.dataType(), fieldInfo2.getType()));
        }
        if (hashMap.containsKey(this.autoPkName)) {
            FieldInfo fieldInfo3 = (FieldInfo) hashMap.get(this.autoPkName);
            if (fieldInfo3.getType() != FieldType.Text) {
                throw new IllegalArgumentException("Wrong type of autopk column " + this.autoPkName + " -> " + fieldInfo3.getType());
            }
            hashMap2.put(this.autoPkName, fieldInfo3.toYdbType());
            hashMap3.put(this.autoPkName, new RandomReader());
        }
        tech.ydb.table.values.StructType of = tech.ydb.table.values.StructType.of(hashMap2);
        ValueReader[] valueReaderArr = new ValueReader[of.getMembersCount()];
        for (int i4 = 0; i4 < of.getMembersCount(); i4++) {
            valueReaderArr[i4] = (ValueReader) hashMap3.get(of.getMemberName(i4));
        }
        if (this.method == IngestMethod.BULK_UPSERT) {
            return new YdbDataWriter(this.types, of, valueReaderArr, this.maxBatchSize) { // from class: tech.ydb.spark.connector.write.YdbWriterFactory.1
                @Override // tech.ydb.spark.connector.write.YdbDataWriter
                CompletableFuture<Status> executeWrite(ListValue listValue) {
                    return YdbWriterFactory.this.table.getCtx().getExecutor().executeBulkUpsert(YdbWriterFactory.this.table.getTablePath(), listValue);
                }
            };
        }
        final String makeBatchSql = makeBatchSql(this.method.name(), this.table.getTablePath(), of);
        return new YdbDataWriter(this.types, of, valueReaderArr, this.maxBatchSize) { // from class: tech.ydb.spark.connector.write.YdbWriterFactory.2
            @Override // tech.ydb.spark.connector.write.YdbDataWriter
            CompletableFuture<Status> executeWrite(ListValue listValue) {
                return YdbWriterFactory.this.table.getCtx().getExecutor().executeDataQuery(makeBatchSql, Params.of("$input", listValue));
            }
        };
    }

    private static String makeBatchSql(String str, String str2, tech.ydb.table.values.StructType structType) {
        StringBuilder sb = new StringBuilder();
        sb.append("DECLARE $input AS List<Struct<");
        for (int i = 0; i < structType.getMembersCount(); i++) {
            sb.append('`').append(structType.getMemberName(i)).append('`');
            sb.append(": ").append(structType.getMemberType(i));
            if (i + 1 < structType.getMembersCount()) {
                sb.append(", ");
            }
        }
        sb.append(">>;\n");
        sb.append(str).append(" INTO `").append(str2).append("`  SELECT * FROM AS_TABLE($input);");
        return sb.toString();
    }
}
