package com.alibaba.hbase.client;

import com.alibaba.lindorm.client.TableService;
import com.alibaba.lindorm.client.core.LindormWideColumnService;
import com.alibaba.lindorm.client.core.utils.Bytes;
import com.alibaba.lindorm.client.dml.Aggregate;
import com.alibaba.lindorm.client.dml.ColumnValue;
import com.alibaba.lindorm.client.dml.Condition;
import com.alibaba.lindorm.client.dml.ConditionFactory;
import com.alibaba.lindorm.client.dml.Row;
import com.alibaba.lindorm.client.exception.LindormException;
import com.alibaba.lindorm.client.schema.DataType;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;
import com.google.protobuf.RpcCallback;
import com.google.protobuf.RpcController;
import org.apache.commons.lang.exception.ExceptionUtils;
import org.apache.hadoop.hbase.HConstants;
import org.apache.hadoop.hbase.KeyValue;
import org.apache.hadoop.hbase.client.Scan;
import org.apache.hadoop.hbase.coprocessor.ColumnInterpreter;
import org.apache.hadoop.hbase.protobuf.ProtobufUtil;
import org.apache.hadoop.hbase.protobuf.ResponseConverter;
import org.apache.hadoop.hbase.protobuf.generated.AggregateProtos;
import org.apache.hadoop.hbase.protobuf.generated.ClientProtos;

import java.io.IOException;
import java.util.NavigableSet;

import static com.alibaba.lindorm.client.dml.ConditionFactory.compare;

/**
 * Implementation of the HBase {@code AggregateService} protobuf endpoint that delegates the actual
 * aggregation to a Lindorm {@link TableService}.
 *
 * <p>Each RPC method translates the incoming {@link AggregateProtos.AggregateRequest} into a Lindorm
 * {@link Aggregate}, executes it, and converts the resulting {@link ColumnValue} back into an
 * {@link AggregateProtos.AggregateResponse} using the {@link ColumnInterpreter} named in the request.
 * {@link IOException}s are reported through the {@link RpcController}; the callback (when non-null)
 * is then invoked with the response, which is {@code null} if an {@code IOException} occurred.
 *
 * @param <T> cell value type of the {@link ColumnInterpreter}
 * @param <S> promoted value type of the {@link ColumnInterpreter}
 * @param <P> protobuf message type used to initialize the {@link ColumnInterpreter}
 * @param <Q> protobuf message type for the promoted value
 * @param <R> protobuf message type for the cell value
 */
public class AliHBaseUEAggregateService<T, S, P extends Message, Q extends Message, R extends Message>
    extends AggregateProtos.AggregateService {

  /** Name of the primary-key column used when turning scan row bounds into conditions. */
  private static final byte[] PK_COLUMN_NAME = LindormWideColumnService.UNIFIED_PK_COLUMN_NAME.getBytes();

  private final TableService tableService;

  private final String table;

  /**
   * @param tableService Lindorm service used to build and execute aggregates
   * @param table name of the table the aggregates run against
   */
  public AliHBaseUEAggregateService(TableService tableService, String table) {
    this.tableService = tableService;
    this.table = table;
  }

  /**
   * Computes the maximum of the requested column and returns it as the first part of the response.
   */
  @Override
  public void getMax(RpcController rpcController, AggregateProtos.AggregateRequest aggregateRequest,
      RpcCallback<AggregateProtos.AggregateResponse> rpcCallback) {
    AggregateProtos.AggregateResponse response = null;
    try {
      LindormAggregate lindormAggregate = toLindormAggregate(aggregateRequest);
      ColumnValue cv = lindormAggregate.max();

      ColumnInterpreter<T, S, P, Q, R> ci = constructColumnInterpreterFromRequest(aggregateRequest);
      T max = decodeCellValue(ci, lindormAggregate, cv.getBinary());
      response = newCellResponseBuilder(ci, max).build();
    } catch (IOException e) {
      ResponseConverter.setControllerException(rpcController, e);
    }
    reply(rpcCallback, response);
  }

  /**
   * Computes the minimum of the requested column and returns it as the first part of the response.
   */
  @Override
  public void getMin(RpcController rpcController, AggregateProtos.AggregateRequest aggregateRequest,
      RpcCallback<AggregateProtos.AggregateResponse> rpcCallback) {
    AggregateProtos.AggregateResponse response = null;
    try {
      LindormAggregate lindormAggregate = toLindormAggregate(aggregateRequest);
      ColumnValue cv = lindormAggregate.min();

      ColumnInterpreter<T, S, P, Q, R> ci = constructColumnInterpreterFromRequest(aggregateRequest);
      T min = decodeCellValue(ci, lindormAggregate, cv.getBinary());
      response = newCellResponseBuilder(ci, min).build();
    } catch (IOException e) {
      ResponseConverter.setControllerException(rpcController, e);
    }
    reply(rpcCallback, response);
  }

  /**
   * Computes the sum of the requested column and returns it as the first part of the response.
   */
  @Override
  public void getSum(RpcController rpcController, AggregateProtos.AggregateRequest aggregateRequest,
      RpcCallback<AggregateProtos.AggregateResponse> rpcCallback) {
    AggregateProtos.AggregateResponse response = null;
    try {
      LindormAggregate lindormAggregate = toLindormAggregate(aggregateRequest);
      ColumnValue cv = lindormAggregate.sum();

      ColumnInterpreter<T, S, P, Q, R> ci = constructColumnInterpreterFromRequest(aggregateRequest);
      T sum = decodeCellValue(ci, lindormAggregate, ElementConvertor.toValueBytes(cv, ci));
      response = newCellResponseBuilder(ci, sum).build();
    } catch (IOException e) {
      ResponseConverter.setControllerException(rpcController, e);
    }
    reply(rpcCallback, response);
  }

  /**
   * Counts the requested column and returns the count, encoded as a long, as the first part of the
   * response.
   */
  @Override
  public void getRowNum(RpcController rpcController, AggregateProtos.AggregateRequest aggregateRequest,
      RpcCallback<AggregateProtos.AggregateResponse> rpcCallback) {
    AggregateProtos.AggregateResponse response = null;
    try {
      LindormAggregate lindormAggregate = toLindormAggregate(aggregateRequest);
      ColumnValue cv = lindormAggregate.count();

      response = AggregateProtos.AggregateResponse.newBuilder()
          .addFirstPart(ByteString.copyFrom(Bytes.toBytes(cv.getLong().longValue()))).build();
    } catch (IOException e) {
      ResponseConverter.setControllerException(rpcController, e);
    }
    reply(rpcCallback, response);
  }

  /**
   * Computes the average of the requested column. The average is returned as the first part of the
   * response and the second part is the long value {@code 1}.
   */
  @Override
  public void getAvg(RpcController rpcController, AggregateProtos.AggregateRequest aggregateRequest,
      RpcCallback<AggregateProtos.AggregateResponse> rpcCallback) {
    AggregateProtos.AggregateResponse response = null;
    try {
      LindormAggregate lindormAggregate = toLindormAggregate(aggregateRequest);
      ColumnValue cv = lindormAggregate.avg();

      ColumnInterpreter<T, S, P, Q, R> ci = constructColumnInterpreterFromRequest(aggregateRequest);
      T avg = decodeCellValue(ci, lindormAggregate, ElementConvertor.toValueBytes(cv, ci));

      // NOTE(review): the second part is fixed to 1, presumably so that a client dividing the first
      // part by the second gets the already-computed average back - confirm against the caller.
      response = newCellResponseBuilder(ci, avg)
          .setSecondPart(ByteString.copyFrom(Bytes.toBytes(1L)))
          .build();
    } catch (IOException e) {
      ResponseConverter.setControllerException(rpcController, e);
    }
    reply(rpcCallback, response);
  }

  /**
   * Not supported.
   *
   * @throws UnsupportedOperationException always
   */
  @Override
  public void getStd(RpcController rpcController, AggregateProtos.AggregateRequest aggregateRequest,
      RpcCallback<AggregateProtos.AggregateResponse> rpcCallback) {
    throw new UnsupportedOperationException("GetStd unsupported !");
  }

  /**
   * Not supported.
   *
   * @throws UnsupportedOperationException always
   */
  @Override
  public void getMedian(RpcController rpcController, AggregateProtos.AggregateRequest aggregateRequest,
      RpcCallback<AggregateProtos.AggregateResponse> rpcCallback) {
    throw new UnsupportedOperationException("GetMedian unsupported !");
  }

  /** Decodes raw value bytes into the interpreter's cell type by wrapping them in a KeyValue with an empty row. */
  private T decodeCellValue(ColumnInterpreter<T, S, P, Q, R> ci, LindormAggregate lindormAggregate, byte[] value)
      throws IOException {
    byte[] cfName = lindormAggregate.getCfName();
    byte[] colName = lindormAggregate.getColName();
    return ci.getValue(cfName, colName, new KeyValue(HConstants.EMPTY_BYTE_ARRAY, cfName, colName, value));
  }

  /** Starts a response whose first part is the interpreter's protobuf encoding of {@code value}. */
  private AggregateProtos.AggregateResponse.Builder newCellResponseBuilder(ColumnInterpreter<T, S, P, Q, R> ci,
      T value) {
    return AggregateProtos.AggregateResponse.newBuilder().addFirstPart(ci.getProtoForCellType(value).toByteString());
  }

  /** Hands the response (possibly {@code null}) to the callback when one was supplied. */
  private static void reply(RpcCallback<AggregateProtos.AggregateResponse> rpcCallback,
      AggregateProtos.AggregateResponse response) {
    if (rpcCallback != null) {
      rpcCallback.run(response);
    }
  }

  /**
   * Translates an aggregate request into a {@link LindormAggregate} bound to this service's table.
   *
   * <p>The first column family of the scan is used, together with the first qualifier registered
   * for that family (or {@code null} when the family has no qualifiers).
   *
   * @throws UnsupportedOperationException if the request has no scan, the scan names no column
   *     family, no interpreter class is given, or the scan uses an unsupported feature
   * @throws IOException if the scan cannot be converted or the condition cannot be built
   */
  private LindormAggregate toLindormAggregate(AggregateProtos.AggregateRequest aggregateRequest) throws IOException {
    if (!aggregateRequest.hasScan()) {
      throw new UnsupportedOperationException("Scan is null !");
    }
    Condition condition = convertRequestToCondition(aggregateRequest);

    Aggregate aggregate = tableService.aggregate().from(table);
    if (condition != null) {
      aggregate.where(condition);
    }

    Scan scanner = ProtobufUtil.toScan(aggregateRequest.getScan());
    byte[][] families = scanner.getFamilies();
    if (families == null || families.length == 0) {
      // Fail with a clear message instead of a NullPointerException/ArrayIndexOutOfBoundsException.
      throw new UnsupportedOperationException("Scan must specify a column family !");
    }
    byte[] colFamily = families[0];
    NavigableSet<byte[]> qualifiers = scanner.getFamilyMap().get(colFamily);
    byte[] qualifier = null;
    if (qualifiers != null && !qualifiers.isEmpty()) {
      // Read-only access: there is no need to remove the qualifier from the scan.
      qualifier = qualifiers.first();
    }

    if (!aggregateRequest.hasInterpreterClassName()) {
      throw new UnsupportedOperationException("Must provide interpreter class");
    }
    DataType interpreterDatype = ElementConvertor.toInterpreterDataType(aggregateRequest.getInterpreterClassName());

    return new LindormAggregate(colFamily, qualifier, interpreterDatype, aggregate);
  }

  /**
   * Builds the primary-key range condition that corresponds to the scan's start and stop rows.
   *
   * <p>Following HBase scan semantics, the start row is inclusive and the stop row is exclusive;
   * an absent or empty bound is treated as unbounded.
   *
   * @param aggregateRequest the request whose scan (if any) supplies the row bounds
   * @return the combined condition, or {@code null} when neither bound is set
   * @throws UnsupportedOperationException if the scan uses a filter or more than one version
   * @throws LindormException if a condition cannot be created
   */
  public Condition convertRequestToCondition(AggregateProtos.AggregateRequest aggregateRequest)
      throws LindormException {
    checkAggregateSupport(aggregateRequest);
    if (!aggregateRequest.hasScan()) {
      return null;
    }
    ClientProtos.Scan scan = aggregateRequest.getScan();

    Condition startRow = null;
    if (scan.hasStartRow() && !scan.getStartRow().isEmpty()) {
      startRow = compare(PK_COLUMN_NAME, ConditionFactory.CompareOp.GREATER_OR_EQUAL,
          scan.getStartRow().toByteArray());
    }
    Condition stopRow = null;
    if (scan.hasStopRow() && !scan.getStopRow().isEmpty()) {
      stopRow = compare(PK_COLUMN_NAME, ConditionFactory.CompareOp.LESS, scan.getStopRow().toByteArray());
    }

    if (startRow != null && stopRow != null) {
      return ConditionFactory.and(startRow, stopRow);
    }
    if (startRow != null) {
      return ConditionFactory.and(startRow);
    }
    if (stopRow != null) {
      return ConditionFactory.and(stopRow);
    }
    return null;
  }

  /**
   * Rejects scans that use features this service does not translate.
   *
   * @throws UnsupportedOperationException if the scan carries a filter or asks for more than one
   *     version
   */
  public static void checkAggregateSupport(AggregateProtos.AggregateRequest aggregateRequest) {
    if (!aggregateRequest.hasScan()) {
      return;
    }
    ClientProtos.Scan scan = aggregateRequest.getScan();
    if (scan.hasFilter()) {
      throw new UnsupportedOperationException("Filter unsupported !");
    }
    if (scan.hasMaxVersions() && scan.getMaxVersions() > 1) {
      throw new UnsupportedOperationException("Versions unsupported ! current : " + scan.getMaxVersions());
    }
  }

  /**
   * Instantiates the {@link ColumnInterpreter} named in the request via its no-arg constructor and,
   * when the request carries interpreter-specific bytes, initializes it with the parsed message.
   *
   * @throws IOException wrapping any failure to load, instantiate or initialize the interpreter
   */
  ColumnInterpreter<T, S, P, Q, R> constructColumnInterpreterFromRequest(AggregateProtos.AggregateRequest request)
      throws IOException {
    String className = request.getInterpreterClassName();
    try {
      Class<?> cls = Class.forName(className);
      // The concrete type arguments cannot be checked at runtime; the caller chooses the
      // interpreter class and is responsible for it matching this service's type parameters.
      @SuppressWarnings("unchecked")
      ColumnInterpreter<T, S, P, Q, R> ci =
          (ColumnInterpreter<T, S, P, Q, R>) cls.getDeclaredConstructor().newInstance();
      if (request.hasInterpreterSpecificBytes()) {
        ByteString b = request.getInterpreterSpecificBytes();
        P initMsg = ProtobufUtil.getParsedGenericInstance(ci.getClass(), 2, b);
        ci.initialize(initMsg);
      }
      return ci;
    } catch (Exception e) {
      throw new IOException(e);
    }
  }

  /**
   * Pairs a Lindorm {@link Aggregate} with the column it targets and exposes one method per
   * aggregate function. Every function aliases its result to the column's own name, executes the
   * aggregate and returns the value stored under that name in the result row.
   */
  static class LindormAggregate {

    private final byte[] cfName;

    private final byte[] colName;

    private final DataType interpreterDatype;

    private final Aggregate aggregate;

    public LindormAggregate(byte[] cfName, byte[] colName, DataType interpreterDatype, Aggregate aggregate) {
      this.cfName = cfName;
      this.colName = colName;
      this.interpreterDatype = interpreterDatype;
      this.aggregate = aggregate;
    }

    public byte[] getCfName() {
      return cfName;
    }

    public byte[] getColName() {
      return colName;
    }

    public DataType getInterpreterDatype() {
      return interpreterDatype;
    }

    public Aggregate getAggregate() {
      return aggregate;
    }

    /** Registers a sum over the column, executes the aggregate and returns the result value. */
    public ColumnValue sum() throws LindormException {
      String colStrName = Bytes.toString(colName);
      aggregate.sumAs(Bytes.toString(cfName), colStrName, colStrName, interpreterDatype);
      return executeAndFetch();
    }

    /** Registers a max over the column, executes the aggregate and returns the result value. */
    public ColumnValue max() throws LindormException {
      String colStrName = Bytes.toString(colName);
      aggregate.maxAs(Bytes.toString(cfName), colStrName, colStrName);
      return executeAndFetch();
    }

    /** Registers a min over the column, executes the aggregate and returns the result value. */
    public ColumnValue min() throws LindormException {
      String colStrName = Bytes.toString(colName);
      aggregate.minAs(Bytes.toString(cfName), colStrName, colStrName);
      return executeAndFetch();
    }

    /** Registers an average over the column, executes the aggregate and returns the result value. */
    public ColumnValue avg() throws LindormException {
      String colStrName = Bytes.toString(colName);
      aggregate.avgAs(Bytes.toString(cfName), colStrName, colStrName, interpreterDatype);
      return executeAndFetch();
    }

    /** Registers a count over the column, executes the aggregate and returns the result value. */
    public ColumnValue count() throws LindormException {
      String colStrName = Bytes.toString(colName);
      aggregate.countAs(Bytes.toString(cfName), colStrName, colStrName);
      return executeAndFetch();
    }

    /** Runs the aggregate and looks up the value aliased to the column name in the result row. */
    private ColumnValue executeAndFetch() throws LindormException {
      Row row = aggregate.execute();
      return row.getColumnValue(colName);
    }
  }

}
