Fix handling of varchar constants in MOT JIT

This commit is contained in:
Vinoth Veeraraghavan 2022-03-21 10:12:22 +08:00
parent 3a552c955d
commit 79516bb52d
14 changed files with 606 additions and 5 deletions

View File

@ -242,6 +242,40 @@ extern bool IsTypeSupported(int resultType)
}
}
extern bool IsStringType(int type)
{
switch (type) {
case VARCHAROID:
case BPCHAROID:
case TEXTOID:
case BYTEAOID:
return true;
default:
return false;
}
}
extern bool IsPrimitiveType(int type)
{
switch (type) {
case BOOLOID:
case CHAROID:
case INT1OID:
case INT2OID:
case INT4OID:
case INT8OID:
case TIMEOID:
case TIMESTAMPOID:
case DATEOID:
case FLOAT4OID:
case FLOAT8OID:
return true;
default:
return false;
}
}
static bool IsEqualsWhereOperator(int whereOp)
{
bool result = false;
@ -623,4 +657,234 @@ extern bool PrepareSubQueryData(JitContext* jitContext, JitCompoundPlan* plan)
return result;
}
static bool CloneStringDatum(Datum source, Datum* target, JitContextUsage usage)
{
bytea* value = DatumGetByteaP(source);
size_t len = VARSIZE(value); // includes header len VARHDRSZ
char* src = VARDATA(value);
// special case: empty string
if (len == 0) {
len = VARHDRSZ;
}
size_t strSize = len - VARHDRSZ;
MOT_LOG_TRACE("CloneStringDatum(): len = %u, src = %*.*s", (unsigned)len, strSize, strSize, src);
bytea* copy = nullptr;
if (usage == JIT_CONTEXT_GLOBAL) {
copy = (bytea*)MOT::MemGlobalAlloc(len);
} else {
copy = (bytea*)MOT::MemSessionAlloc(len);
}
if (copy == nullptr) {
MOT_REPORT_ERROR(
MOT_ERROR_OOM, "JIT Compile", "Failed to allocate %u bytes for datum string constant", (unsigned)len);
return false;
}
if (strSize > 0) {
errno_t erc = memcpy_s(VARDATA(copy), strSize, (uint8_t*)src, strSize);
securec_check(erc, "\0", "\0");
}
SET_VARSIZE(copy, len);
*target = PointerGetDatum(copy);
return true;
}
static bool CloneTimeTzDatum(Datum source, Datum* target, JitContextUsage usage)
{
MOT::TimetzSt* value = (MOT::TimetzSt*)DatumGetPointer(source);
size_t allocSize = sizeof(MOT::TimetzSt);
MOT::TimetzSt* copy = nullptr;
if (usage == JIT_CONTEXT_GLOBAL) {
copy = (MOT::TimetzSt*)MOT::MemGlobalAlloc(allocSize);
} else {
copy = (MOT::TimetzSt*)MOT::MemSessionAlloc(allocSize);
}
if (copy == nullptr) {
MOT_REPORT_ERROR(
MOT_ERROR_OOM, "JIT Compile", "Failed to allocate %u bytes for datum TimeTZ constant", (unsigned)allocSize);
return false;
}
copy->m_time = value->m_time;
copy->m_zone = value->m_zone;
*target = PointerGetDatum(copy);
return true;
}
static bool CloneIntervalDatum(Datum source, Datum* target, JitContextUsage usage)
{
MOT::IntervalSt* value = (MOT::IntervalSt*)DatumGetPointer(source);
size_t allocSize = sizeof(MOT::IntervalSt);
MOT::IntervalSt* copy = nullptr;
if (usage == JIT_CONTEXT_GLOBAL) {
copy = (MOT::IntervalSt*)MOT::MemGlobalAlloc(allocSize);
} else {
copy = (MOT::IntervalSt*)MOT::MemSessionAlloc(allocSize);
}
if (copy == nullptr) {
MOT_REPORT_ERROR(MOT_ERROR_OOM,
"JIT Compile",
"Failed to allocate %u bytes for datum Interval constant",
(unsigned)allocSize);
return false;
}
copy->m_day = value->m_day;
copy->m_month = value->m_month;
copy->m_time = value->m_time;
*target = PointerGetDatum(copy);
return true;
}
static bool CloneTIntervalDatum(Datum source, Datum* target, JitContextUsage usage)
{
MOT::TintervalSt* value = (MOT::TintervalSt*)DatumGetPointer(source);
size_t allocSize = sizeof(MOT::TintervalSt);
MOT::TintervalSt* copy = nullptr;
if (usage == JIT_CONTEXT_GLOBAL) {
copy = (MOT::TintervalSt*)MOT::MemGlobalAlloc(allocSize);
} else {
copy = (MOT::TintervalSt*)MOT::MemSessionAlloc(allocSize);
}
if (copy == nullptr) {
MOT_REPORT_ERROR(MOT_ERROR_OOM,
"JIT Compile",
"Failed to allocate %u bytes for datum TInterval constant",
(unsigned)allocSize);
return false;
}
copy->m_status = value->m_status;
copy->m_data[0] = value->m_data[0];
copy->m_data[1] = value->m_data[1];
*target = PointerGetDatum(copy);
return true;
}
static bool CloneNumericDatum(Datum source, Datum* target, JitContextUsage usage)
{
varlena* var = (varlena*)DatumGetPointer(source);
Size len = VARSIZE(var);
struct varlena* result = nullptr;
if (usage == JIT_CONTEXT_GLOBAL) {
result = (varlena*)MOT::MemGlobalAlloc(len);
} else {
result = (varlena*)MOT::MemSessionAlloc(len);
}
if (result == nullptr) {
MOT_REPORT_ERROR(
MOT_ERROR_OOM, "JIT Compile", "Failed to allocate %u bytes for datum Numeric constant", (unsigned)len);
return false;
}
errno_t rc = memcpy_s(result, len, var, len);
securec_check(rc, "\0", "\0");
*target = NumericGetDatum((Numeric)result);
return true;
}
static bool CloneCStringDatum(Datum source, Datum* target, JitContextUsage usage)
{
char* src = DatumGetCString(source);
size_t len = strlen(src) + 1; // includes terminating null
char* copy = nullptr;
if (usage == JIT_CONTEXT_GLOBAL) {
copy = (char*)MOT::MemGlobalAlloc(len);
} else {
copy = (char*)MOT::MemSessionAlloc(len);
}
if (copy == nullptr) {
MOT_REPORT_ERROR(
MOT_ERROR_OOM, "JIT Compile", "Failed to allocate %u bytes for datum string constant", (unsigned)len);
return false;
}
errno_t erc = memcpy_s(copy, len, src, len);
securec_check(erc, "\0", "\0");
*target = PointerGetDatum(copy);
return true;
}
extern bool CloneDatum(Datum source, int type, Datum* target, JitContextUsage usage)
{
bool result = true;
if (IsStringType(type)) {
result = CloneStringDatum(source, target, usage);
} else {
switch (type) {
case TIMETZOID:
result = CloneTimeTzDatum(source, target, usage);
break;
case INTERVALOID:
result = CloneIntervalDatum(source, target, usage);
break;
case TINTERVALOID:
result = CloneTIntervalDatum(source, target, usage);
break;
case NUMERICOID:
result = CloneNumericDatum(source, target, usage);
break;
case UNKNOWNOID:
result = CloneCStringDatum(source, target, usage);
break;
default:
MOT_LOG_TRACE("Unsupported non-primitive constant type: %d", type);
result = false;
break;
}
}
return result;
}
extern bool PrepareDatumArray(Const* constArray, uint32_t constCount, JitDatumArray* datumArray)
{
size_t allocSize = sizeof(JitDatum) * constCount;
JitDatum* datums = (JitDatum*)MOT::MemGlobalAlloc(allocSize);
if (datums == nullptr) {
MOT_REPORT_ERROR(
MOT_ERROR_OOM, "JIT Compile", "Failed to allocate %u bytes for constant datum array", (unsigned)allocSize);
return false;
}
for (uint32_t i = 0; i < constCount; ++i) {
Const* constValue = &constArray[i];
datums[i].m_isNull = constValue->constisnull;
datums[i].m_type = constValue->consttype;
if (!datums[i].m_isNull) {
if (IsPrimitiveType(constValue->constvalue)) {
datums[i].m_datum = constValue->constvalue;
} else {
if (!CloneDatum(
constValue->constvalue, constValue->consttype, &datums[i].m_datum, JIT_CONTEXT_GLOBAL)) {
MOT_LOG_TRACE("Failed to prepare datum value");
for (uint32_t j = 0; j < i; ++j) {
if (!datums[j].m_isNull && !IsPrimitiveType(datums[j].m_type)) {
MOT::MemGlobalFree(DatumGetPointer(datums[j].m_datum));
}
}
MOT::MemGlobalFree(datums);
return false;
}
}
}
}
datumArray->m_datumCount = constCount;
datumArray->m_datums = datums;
return true;
}
} // namespace JitExec

View File

@ -31,6 +31,9 @@
#include "mot_engine.h"
/** @define The maximum number of constant objects that can be used in a query. */
#define MOT_JIT_MAX_CONST 1024
// This file contains definitions used both by LLVM and TVM jitted code
namespace JitExec {
// forward declaration
@ -64,6 +67,12 @@ extern int BuildIndexColumnOffsets(MOT::Table* table, const MOT::Index* index, i
/** @brief Queries whether a PG type is supported by MOT tables. */
extern bool IsTypeSupported(int resultType);
/** @brief Queries whether a PG type represents a string. */
extern bool IsStringType(int type);
/** @brief Queries whether a PG type represents a primitive type. */
extern bool IsPrimitiveType(int type);
/** @brief Queries whether a WHERE clause operator is supported. */
extern bool IsWhereOperatorSupported(int whereOp);
@ -125,6 +134,12 @@ extern void DestroyTableInfo(TableInfo* table_info);
* @return True if operations succeeded, otherwise false.
*/
extern bool PrepareSubQueryData(JitContext* jitContext, JitCompoundPlan* plan);
/** @brief Prepares array of global datum objects from array of constants. */
extern bool PrepareDatumArray(Const* constArray, uint32_t constCount, JitDatumArray* datumArray);
/** @brief Clones an interval datum into global memory. */
extern bool CloneDatum(Datum source, int type, Datum* target, JitContextUsage usage);
} // namespace JitExec
#endif

View File

@ -97,6 +97,63 @@ extern void FreeJitContext(JitContext* jitContext)
}
}
static bool CloneDatumArray(JitDatumArray* source, JitDatumArray* target, JitContextUsage usage)
{
uint32_t datumCount = source->m_datumCount;
if (datumCount == 0) {
target->m_datumCount = 0;
target->m_datums = nullptr;
return true;
}
size_t allocSize = sizeof(JitDatum) * datumCount;
JitDatum* datumArray = nullptr;
if (usage == JIT_CONTEXT_GLOBAL) {
datumArray = (JitDatum*)MOT::MemGlobalAlloc(allocSize);
} else {
datumArray = (JitDatum*)MOT::MemSessionAlloc(allocSize);
}
if (datumArray == nullptr) {
MOT_REPORT_ERROR(
MOT_ERROR_OOM, "JIT Compile", "Failed to allocate %u bytes for datum array", (unsigned)allocSize);
return false;
}
for (uint32_t i = 0; i < datumCount; ++i) {
JitDatum* datum = (JitDatum*)&source->m_datums[i];
datumArray[i].m_isNull = datum->m_isNull;
datumArray[i].m_type = datum->m_type;
if (!datum->m_isNull) {
if (IsPrimitiveType(datum->m_type)) {
datumArray[i].m_datum = datum->m_datum;
} else {
if (!CloneDatum(datum->m_datum, datum->m_type, &datumArray[i].m_datum, usage)) {
MOT_REPORT_ERROR(MOT_ERROR_OOM, "JIT Compile", "Failed to clone datum array entry");
for (uint32_t j = 0; j < i; ++j) {
if (!IsPrimitiveType(datumArray[j].m_type)) {
if (usage == JIT_CONTEXT_GLOBAL) {
MOT::MemGlobalFree(DatumGetPointer(datumArray[j].m_datum));
} else {
MOT::MemGlobalFree(DatumGetPointer(datumArray[j].m_datum));
}
}
}
if (usage == JIT_CONTEXT_GLOBAL) {
MOT::MemGlobalFree(datumArray);
} else {
MOT::MemSessionFree(datumArray);
}
return false;
}
}
}
}
target->m_datums = datumArray;
target->m_datumCount = datumCount;
return true;
}
extern JitContext* CloneJitContext(JitContext* sourceJitContext)
{
MOT_LOG_TRACE("Cloning JIT context %p of query: %s", sourceJitContext, sourceJitContext->m_queryString);
@ -109,6 +166,11 @@ extern JitContext* CloneJitContext(JitContext* sourceJitContext)
result->m_llvmFunction = sourceJitContext->m_llvmFunction;
result->m_tvmFunction = sourceJitContext->m_tvmFunction;
result->m_commandType = sourceJitContext->m_commandType;
if (!CloneDatumArray(&sourceJitContext->m_constDatums, &result->m_constDatums, JIT_CONTEXT_LOCAL)) {
MOT_REPORT_ERROR(MOT_ERROR_OOM, "JIT Compile", "Failed to clone constant datum array");
DestroyJitContext(result);
return nullptr;
}
result->m_table = sourceJitContext->m_table;
result->m_index = sourceJitContext->m_index;
result->m_indexId = sourceJitContext->m_indexId;
@ -492,6 +554,29 @@ extern bool PrepareJitContext(JitContext* jitContext)
return true;
}
static void DestroyDatumArray(JitDatumArray* datumArray, JitContextUsage usage)
{
if (datumArray->m_datumCount > 0) {
MOT_ASSERT(datumArray->m_datums != nullptr);
for (uint32_t i = 0; i < datumArray->m_datumCount; ++i) {
if (!datumArray->m_datums[i].m_isNull && !IsPrimitiveType(datumArray->m_datums[i].m_type)) {
if (usage == JIT_CONTEXT_GLOBAL) {
MOT::MemGlobalFree(DatumGetPointer(datumArray->m_datums[i].m_datum));
} else {
MOT::MemSessionFree(DatumGetPointer(datumArray->m_datums[i].m_datum));
}
}
}
if (usage == JIT_CONTEXT_GLOBAL) {
MOT::MemGlobalFree(datumArray->m_datums);
} else {
MOT::MemSessionFree(datumArray->m_datums);
}
datumArray->m_datums = nullptr;
datumArray->m_datumCount = 0;
}
}
extern void DestroyJitContext(JitContext* jitContext)
{
if (jitContext != nullptr) {
@ -513,6 +598,9 @@ extern void DestroyJitContext(JitContext* jitContext)
jitContext->m_jitSource = nullptr;
}
// cleanup constant datum array
DestroyDatumArray(&jitContext->m_constDatums, jitContext->m_usage);
// cleanup sub-query data array
CleanupJitContextSubQueryDataArray(jitContext);

View File

@ -34,6 +34,26 @@ namespace JitExec {
struct JitContextPool;
struct JitSource;
/** @struct Array of constant datum objects used in JIT execution. */
struct JitDatum {
/** @var The constant value. */
Datum m_datum;
/** @var The constant type. */
int m_type;
/** @var The constant is-null property. */
int m_isNull;
};
struct JitDatumArray {
/** @var The number of constant datum objects used by the jitted function (global copy for all contexts). */
uint64_t m_datumCount;
/** @var The array of constant datum objects used by the jitted function (global copy for all contexts). */
JitDatum* m_datums;
};
/**
* @typedef The context for executing a jitted function.
*/
@ -97,6 +117,9 @@ struct JitContext {
/** @var The source query string. */
const char* m_queryString; // L1 offset 40 (constant)
/** @var The array of constant datum objects used by the jitted function (global copy for all contexts). */
JitDatumArray m_constDatums;
/*---------------------- Range Scan execution state -------------------*/
/** @var Begin iterator for range select (stateful execution). */
MOT::IndexIterator* m_beginIterator; // L1 offset 48

View File

@ -392,6 +392,22 @@ int getExprArgIsNull(int arg_pos)
return result;
}
Datum GetConstAt(int constId, int argPos)
{
MOT_LOG_DEBUG("Retrieving constant datum by id %d", constId);
Datum result = PointerGetDatum(nullptr);
JitExec::JitContext* ctx = u_sess->mot_cxt.jit_context;
if (constId < (int)ctx->m_constDatums.m_datumCount) {
JitExec::JitDatum* datum = &ctx->m_constDatums.m_datums[constId];
result = datum->m_datum;
setExprArgIsNull(argPos, datum->m_isNull);
DBG_PRINT_DATUM("Retrieved constant datum", datum->m_type, datum->m_datum, datum->m_isNull);
} else {
MOT_LOG_ERROR("Invalid constant identifier: %d", constId);
}
return result;
}
Datum getDatumParam(ParamListInfo params, int paramid, int arg_pos)
{
MOT_LOG_DEBUG("Retrieving datum param at index %d", paramid);

View File

@ -96,6 +96,14 @@ void setExprArgIsNull(int arg_pos, int isnull);
*/
int getExprArgIsNull(int arg_pos);
/**
* @brief Retrieves a pooled constant by its identifier.
* @param constId The identifier of the constant value.
* @param argPos The ordinal position of the enveloping parameter expression.
* @return The constant value.
*/
Datum GetConstAt(int constId, int argPos);
/**
* @brief Retrieves a datum parameter from parameters array.
* @param params The parameter array.

View File

@ -454,10 +454,21 @@ static llvm::Value* ProcessConstExpr(
if (IsTypeSupported(const_value->consttype)) {
result_type = const_value->consttype;
AddSetExprArgIsNull(ctx, arg_pos, const_value->constisnull); // mark expression null status
result = llvm::ConstantInt::get(ctx->INT64_T, const_value->constvalue, true);
if (IsPrimitiveType(result_type)) {
result = llvm::ConstantInt::get(ctx->INT64_T, const_value->constvalue, true);
} else {
int constId = AllocateConstId(ctx, result_type, const_value->constvalue, const_value->constisnull);
if (constId == -1) {
MOT_LOG_TRACE("Failed to allocate constant identifier");
} else {
result = AddGetConstAt(ctx, constId, arg_pos);
}
}
if (max_arg && (arg_pos > *max_arg)) {
*max_arg = arg_pos;
}
} else {
MOT_LOG_TRACE("Failed to process const expression: type %d unsupported", (int)result_type);
}
MOT_LOG_DEBUG("%*s <-- Processing CONST expression result: %p", depth, "", result);
@ -769,8 +780,18 @@ static llvm::Value* ProcessExpr(
static llvm::Value* ProcessConstExpr(JitLlvmCodeGenContext* ctx, const JitConstExpr* expr, int* max_arg)
{
llvm::Value* result = nullptr;
AddSetExprArgIsNull(ctx, expr->_arg_pos, expr->_is_null); // mark expression null status
llvm::Value* result = llvm::ConstantInt::get(ctx->INT64_T, expr->_value, true);
if (IsPrimitiveType(expr->_const_type)) {
result = llvm::ConstantInt::get(ctx->INT64_T, expr->_value, true);
} else {
int constId = AllocateConstId(ctx, expr->_const_type, expr->_value, expr->_is_null);
if (constId == -1) {
MOT_LOG_TRACE("Failed to allocate constant identifier");
} else {
result = AddGetConstAt(ctx, constId, expr->_arg_pos);
}
}
if (max_arg && (expr->_arg_pos > *max_arg)) {
*max_arg = expr->_arg_pos;
}

View File

@ -632,6 +632,11 @@ inline void DefineGetSubQueryEndIteratorKey(JitLlvmCodeGenContext* ctx, llvm::Mo
defineFunction(module, ctx->KeyType->getPointerTo(), "GetSubQueryEndIteratorKey", ctx->INT32_T, nullptr);
}
inline void DefineGetConstAt(JitLlvmCodeGenContext* ctx, llvm::Module* module)
{
ctx->GetConstAtFunc = defineFunction(module, ctx->DATUM_T, "GetConstAt", ctx->INT32_T, ctx->INT32_T, nullptr);
}
/*--------------------------- End of LLVM Helper Prototypes ---------------------------*/
/*--------------------------- Helpers to generate calls to Helper function via LLVM ---------------------------*/
@ -1326,6 +1331,13 @@ inline llvm::Value* AddGetSubQueryEndIteratorKey(JitLlvmCodeGenContext* ctx, int
return AddFunctionCall(ctx, ctx->GetSubQueryEndIteratorKeyFunc, subQueryIndexValue, nullptr);
}
inline llvm::Value* AddGetConstAt(JitLlvmCodeGenContext* ctx, int constId, int argPos)
{
llvm::ConstantInt* constIdValue = llvm::ConstantInt::get(ctx->INT32_T, constId, true);
llvm::ConstantInt* argPosValue = llvm::ConstantInt::get(ctx->INT32_T, argPos, true);
return AddFunctionCall(ctx, ctx->GetConstAtFunc, constIdValue, argPosValue, nullptr);
}
/** @brief Adds a call to issueDebugLog(function, msg). */
#ifdef MOT_JIT_DEBUG
inline void IssueDebugLogImpl(JitLlvmCodeGenContext* ctx, const char* function, const char* msg)

View File

@ -171,6 +171,7 @@ struct JitLlvmCodeGenContext {
llvm::FunctionCallee GetSubQueryIndexFunc;
llvm::FunctionCallee GetSubQuerySearchKeyFunc;
llvm::FunctionCallee GetSubQueryEndIteratorKeyFunc;
llvm::FunctionCallee GetConstAtFunc;
// builtins
#define APPLY_UNARY_OPERATOR(funcid, name) llvm::FunctionCallee _builtin_##name;
@ -224,10 +225,16 @@ struct JitLlvmCodeGenContext {
TableInfo _inner_table_info;
TableInfo* m_subQueryTableInfo;
// non-primitive constants
uint32_t m_constCount;
Const* m_constValues;
dorado::GsCodeGen* _code_gen;
dorado::GsCodeGen::LlvmBuilder* _builder;
llvm::Function* m_jittedQuery;
};
extern int AllocateConstId(JitLlvmCodeGenContext* ctx, int type, Datum value, bool isNull);
} // namespace JitExec
#endif /* JIT_LLVM_QUERY_H */

View File

@ -142,6 +142,7 @@ void InitCodeGenContextFuncs(JitLlvmCodeGenContext* ctx)
DefineGetSubQueryIndex(ctx, module);
DefineGetSubQuerySearchKey(ctx, module);
DefineGetSubQueryEndIteratorKey(ctx, module);
DefineGetConstAt(ctx, module);
}
#define APPLY_UNARY_OPERATOR(funcid, name) \
@ -258,6 +259,18 @@ static bool InitCodeGenContext(JitLlvmCodeGenContext* ctx, GsCodeGen* code_gen,
return false;
}
ctx->m_constCount = 0;
size_t allocSize = sizeof(Const) * MOT_JIT_MAX_CONST;
ctx->m_constValues = (Const*)MOT::MemSessionAlloc(allocSize);
if (ctx->m_constValues == nullptr) {
MOT_REPORT_ERROR(MOT_ERROR_OOM,
"JIT Compile",
"Failed to allocate %u bytes for constant array in code-generation context",
allocSize);
DestroyCodeGenContext(ctx);
return false;
}
InitCodeGenContextTypes(ctx);
InitCodeGenContextFuncs(ctx);
InitCodeGenContextBuiltins(ctx);
@ -351,6 +364,9 @@ static void DestroyCodeGenContext(JitLlvmCodeGenContext* ctx)
MOT::MemSessionFree(ctx->m_subQueryTableInfo);
ctx->m_subQueryTableInfo = nullptr;
}
if (ctx->m_constValues != nullptr) {
MOT::MemSessionFree(ctx->m_constValues);
}
if (ctx->_code_gen != nullptr) {
ctx->_code_gen->releaseResource();
delete ctx->_code_gen;
@ -358,6 +374,24 @@ static void DestroyCodeGenContext(JitLlvmCodeGenContext* ctx)
}
}
extern int AllocateConstId(JitLlvmCodeGenContext* ctx, int type, Datum value, bool isNull)
{
int res = -1;
if (ctx->m_constCount == MOT_JIT_MAX_CONST) {
MOT_REPORT_ERROR(MOT_ERROR_RESOURCE_LIMIT,
"JIT Compile",
"Cannot allocate constant identifier, reached limit of %u",
ctx->m_constCount);
} else {
res = ctx->m_constCount++;
ctx->m_constValues[res].consttype = type;
ctx->m_constValues[res].constvalue = value;
ctx->m_constValues[res].constisnull = isNull;
MOT_LOG_TRACE("Allocated constant id: %d", res);
}
return res;
}
/** @brief Wraps up an LLVM function (compiles it and prepares a funciton pointer). */
static JitContext* FinalizeCodegen(JitLlvmCodeGenContext* ctx, int max_arg, JitCommandType command_type)
{
@ -381,6 +415,15 @@ static JitContext* FinalizeCodegen(JitLlvmCodeGenContext* ctx, int max_arg, JitC
#endif
}
// prepare global constant array
JitDatumArray datumArray = {};
if (ctx->m_constCount > 0) {
if (!PrepareDatumArray(ctx->m_constValues, ctx->m_constCount, &datumArray)) {
MOT_LOG_ERROR("Failed to generate jitted code for query: Failed to prepare constant datum array");
return nullptr;
}
}
// that's it, we are ready
JitContext* jit_context = AllocJitContext(JIT_CONTEXT_GLOBAL);
if (jit_context == nullptr) {
@ -412,6 +455,8 @@ static JitContext* FinalizeCodegen(JitLlvmCodeGenContext* ctx, int max_arg, JitC
MOT_LOG_TRACE("Installed inner index id: %" PRIu64, jit_context->m_innerIndexId);
}
jit_context->m_commandType = command_type;
jit_context->m_constDatums.m_datumCount = datumArray.m_datumCount;
jit_context->m_constDatums.m_datums = datumArray.m_datums;
return jit_context;
}

View File

@ -424,7 +424,17 @@ static Expression* ProcessConstExpr(
if (IsTypeSupported(const_value->consttype)) {
result_type = const_value->consttype;
result = new (std::nothrow) ConstExpression(const_value->constvalue, arg_pos, (int)(const_value->constisnull));
if (IsPrimitiveType(result_type)) {
result =
new (std::nothrow) ConstExpression(const_value->constvalue, arg_pos, (int)(const_value->constisnull));
} else {
int constId = AllocateConstId(ctx, result_type, const_value->constvalue, const_value->constisnull);
if (constId == -1) {
MOT_LOG_TRACE("Failed to allocate constant identifier");
} else {
result = AddGetConstAt(ctx, constId, arg_pos);
}
}
if (max_arg && (arg_pos > *max_arg)) {
*max_arg = arg_pos;
}
@ -662,8 +672,18 @@ static Expression* ProcessExpr(
static Expression* ProcessConstExpr(JitTvmCodeGenContext* ctx, const JitConstExpr* expr, int* max_arg)
{
AddSetExprArgIsNull(ctx, expr->_arg_pos, (expr->_is_null ? 1 : 0)); // mark expression null status
Expression* result = new (std::nothrow) ConstExpression(expr->_value, expr->_arg_pos, (int)(expr->_is_null));
Expression* result = nullptr;
AddSetExprArgIsNull(ctx, expr->_arg_pos, expr->_is_null); // mark expression null status
if (IsPrimitiveType(expr->_const_type)) {
result = new (std::nothrow) ConstExpression(expr->_value, expr->_arg_pos, (int)(expr->_is_null));
} else {
int constId = AllocateConstId(ctx, expr->_const_type, expr->_value, expr->_is_null);
if (constId == -1) {
MOT_LOG_TRACE("Failed to allocate constant identifier");
} else {
result = AddGetConstAt(ctx, constId, expr->_arg_pos);
}
}
if (max_arg && (expr->_arg_pos > *max_arg)) {
*max_arg = expr->_arg_pos;
}

View File

@ -2761,6 +2761,31 @@ private:
int m_subQueryIndex;
};
/** @class GetConstAtExpression */
class GetConstAtExpression : public tvm::Expression {
public:
explicit GetConstAtExpression(int constId, int argPos)
: Expression(tvm::Expression::CanFail), m_constId(constId), m_argPos(argPos)
{}
~GetConstAtExpression() final
{}
Datum eval(tvm::ExecContext* exec_context) final
{
return (uint64_t)GetConstAt(m_constId, m_argPos);
}
void dump() final
{
(void)fprintf(stderr, "GetConstAt(constId=%d, argPos=%d)", m_constId, m_argPos);
}
private:
int m_constId;
int m_argPos;
};
inline tvm::Instruction* AddIsSoftMemoryLimitReached(JitTvmCodeGenContext* ctx)
{
return ctx->_builder->addInstruction(new (std::nothrow) IsSoftMemoryLimitReachedInstruction());
@ -3148,6 +3173,11 @@ inline void AddCopyAggregateToSubQueryResult(JitTvmCodeGenContext* ctx, int subQ
ctx->_builder->addInstruction(new (std::nothrow) CopyAggregateToSubQueryResultInstruction(subQueryIndex));
}
inline tvm::Expression* AddGetConstAt(JitTvmCodeGenContext* ctx, int constId, int argPos)
{
return new (std::nothrow) GetConstAtExpression(constId, argPos);
}
#ifdef MOT_JIT_DEBUG
inline void IssueDebugLogImpl(JitTvmCodeGenContext* ctx, const char* function, const char* msg)
{

View File

@ -61,7 +61,13 @@ struct JitTvmCodeGenContext {
/** @var The resulting jitted function. */
tvm::Function* m_jittedQuery;
// non-primitive constants
uint32_t m_constCount;
Const* m_constValues;
};
extern int AllocateConstId(JitTvmCodeGenContext* ctx, int type, Datum value, bool isNull);
} // namespace JitExec
#endif /* JIT_TVM_QUERY_H */

View File

@ -39,6 +39,8 @@ using namespace tvm;
namespace JitExec {
DECLARE_LOGGER(JitTvmQueryCodegen, JitExec)
static void DestroyCodeGenContext(JitTvmCodeGenContext* ctx);
/** @brief Initializes a context for compilation. */
static bool InitCodeGenContext(JitTvmCodeGenContext* ctx, Builder* builder, MOT::Table* table, MOT::Index* index,
MOT::Table* inner_table = nullptr, MOT::Index* inner_index = nullptr)
@ -59,6 +61,18 @@ static bool InitCodeGenContext(JitTvmCodeGenContext* ctx, Builder* builder, MOT:
return false;
}
ctx->m_constCount = 0;
size_t allocSize = sizeof(Const) * MOT_JIT_MAX_CONST;
ctx->m_constValues = (Const*)MOT::MemSessionAlloc(allocSize);
if (ctx->m_constValues == nullptr) {
MOT_REPORT_ERROR(MOT_ERROR_OOM,
"JIT Compile",
"Failed to allocate %u bytes for constant array in code-generation context",
allocSize);
DestroyCodeGenContext(ctx);
return false;
}
return true;
}
@ -130,9 +144,29 @@ static void DestroyCodeGenContext(JitTvmCodeGenContext* ctx)
for (uint32_t i = 0; i < ctx->m_subQueryCount; ++i) {
DestroyTableInfo(&ctx->m_subQueryTableInfo[i]);
}
if (ctx->m_constValues != nullptr) {
MOT::MemSessionFree(ctx->m_constValues);
}
}
}
extern int AllocateConstId(JitTvmCodeGenContext* ctx, int type, Datum value, bool isNull)
{
int res = -1;
if (ctx->m_constCount == MOT_JIT_MAX_CONST) {
MOT_REPORT_ERROR(MOT_ERROR_RESOURCE_LIMIT,
"JIT Compile",
"Cannot allocate constant identifier, reached limit of %u",
ctx->m_constCount);
} else {
res = ctx->m_constCount++;
ctx->m_constValues[res].consttype = type;
ctx->m_constValues[res].constvalue = value;
ctx->m_constValues[res].constisnull = isNull;
}
return res;
}
static JitContext* FinalizeCodegen(JitTvmCodeGenContext* ctx, int max_arg, JitCommandType command_type)
{
// do minimal verification and wrap up
@ -155,6 +189,16 @@ static JitContext* FinalizeCodegen(JitTvmCodeGenContext* ctx, int max_arg, JitCo
return nullptr;
}
// prepare global constant array
JitDatumArray datumArray = {};
if (ctx->m_constCount > 0) {
if (!PrepareDatumArray(ctx->m_constValues, ctx->m_constCount, &datumArray)) {
MOT_LOG_ERROR("Failed to generate jitted code for query: Failed to prepare constant datum array");
delete ctx->m_jittedQuery;
return nullptr;
}
}
// that's it, we are ready
JitContext* jit_context = AllocJitContext(JIT_CONTEXT_GLOBAL);
if (jit_context == nullptr) {
@ -178,6 +222,8 @@ static JitContext* FinalizeCodegen(JitTvmCodeGenContext* ctx, int max_arg, JitCo
}
jit_context->m_commandType = command_type;
jit_context->m_subQueryCount = 0;
jit_context->m_constDatums.m_datumCount = datumArray.m_datumCount;
jit_context->m_constDatums.m_datums = datumArray.m_datums;
return jit_context;
}