/*
|
* Copyright 2018 Google Inc.
|
*
|
* Use of this source code is governed by a BSD-style license that can be
|
* found in the LICENSE file.
|
*/
|
|
#ifndef SKSL_STANDALONE
|
|
#ifdef SK_LLVM_AVAILABLE
|
|
#include "SkSLJIT.h"
|
|
#include "SkCpu.h"
|
#include "SkRasterPipeline.h"
|
#include "ir/SkSLAppendStage.h"
|
#include "ir/SkSLExpressionStatement.h"
|
#include "ir/SkSLFunctionCall.h"
|
#include "ir/SkSLFunctionReference.h"
|
#include "ir/SkSLIndexExpression.h"
|
#include "ir/SkSLProgram.h"
|
#include "ir/SkSLUnresolvedFunction.h"
|
#include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
|
|
static constexpr int MAX_VECTOR_COUNT = 16;
|
|
extern "C" void sksl_pipeline_append(SkRasterPipeline* p, int stage, void* ctx) {
|
p->append((SkRasterPipeline::StockStage) stage, ctx);
|
}
|
|
#define PTR_SIZE sizeof(void*)
|
|
extern "C" void sksl_pipeline_append_callback(SkRasterPipeline* p, void* fn) {
|
p->append(fn, nullptr);
|
}
|
|
// Debug hook callable from JIT-compiled code: writes f to stdout with a "Debug: " prefix.
extern "C" void sksl_debug_print(float f) {
    // varargs promote float to double anyway; make that explicit
    const double value = f;
    printf("Debug: %f\n", value);
}
|
|
extern "C" float sksl_clamp1(float f, float min, float max) {
|
return SkTPin(f, min, max);
|
}
|
|
// Host-side vector types used to exchange values with JIT-compiled code by value.
// float3 deliberately occupies 16 bytes (four lanes), the same size as float4.
typedef float float2 __attribute__((vector_size(8)));
typedef float float3 __attribute__((vector_size(16)));
typedef float float4 __attribute__((vector_size(16)));
|
|
// Lane-wise clamp() for 2-component vectors: each lane of f is pinned into [min, max].
extern "C" float2 sksl_clamp2(float2 f, float min, float max) {
    float2 pinned = {};
    for (int lane = 0; lane < 2; ++lane) {
        pinned[lane] = SkTPin(f[lane], min, max);
    }
    return pinned;
}
|
|
// Lane-wise clamp() for 3-component vectors. float3 has four lanes; only the first three are
// computed, and the fourth is left zero.
extern "C" float3 sksl_clamp3(float3 f, float min, float max) {
    float3 pinned = {};
    for (int lane = 0; lane < 3; ++lane) {
        pinned[lane] = SkTPin(f[lane], min, max);
    }
    return pinned;
}
|
|
// Lane-wise clamp() for 4-component vectors: each lane of f is pinned into [min, max].
extern "C" float4 sksl_clamp4(float4 f, float min, float max) {
    float4 pinned = {};
    for (int lane = 0; lane < 4; ++lane) {
        pinned[lane] = SkTPin(f[lane], min, max);
    }
    return pinned;
}
|
|
namespace SkSL {
|
|
static constexpr int STAGE_PARAM_COUNT = 12;
|
|
static bool ends_with_branch(const Statement& stmt) {
|
switch (stmt.fKind) {
|
case Statement::kBlock_Kind: {
|
const Block& b = (const Block&) stmt;
|
if (b.fStatements.size()) {
|
return ends_with_branch(*b.fStatements.back());
|
}
|
return false;
|
}
|
case Statement::kBreak_Kind: // fall through
|
case Statement::kContinue_Kind: // fall through
|
case Statement::kReturn_Kind: // fall through
|
return true;
|
default:
|
return false;
|
}
|
}
|
|
// Initializes LLVM's native-target MCJIT support, selects the SIMD width (fVectorCount) and
// target CPU name (fCPU) from the host's features, then creates the LLVM context and every
// scalar, pointer and vector type the code generator uses.
JIT::JIT(Compiler* compiler)
    : fCompiler(*compiler) {
    LLVMInitializeNativeTarget();
    LLVMInitializeNativeAsmPrinter();
    LLVMLinkInMCJIT();

    SkASSERT(!SkCpu::Supports(SkCpu::SKX)); // not yet supported
    // Baseline: 4-wide vectors, no specific CPU. AVX-class hosts get 8-wide vectors.
    fVectorCount = 4;
    fCPU = nullptr;
    if (SkCpu::Supports(SkCpu::HSW)) {
        fVectorCount = 8;
        fCPU = "haswell";
    } else if (SkCpu::Supports(SkCpu::AVX)) {
        fVectorCount = 8;
        fCPU = "ivybridge";
    }

    fContext = LLVMContextCreate();
    fVoidType = LLVMVoidTypeInContext(fContext);

    // booleans (i1) and boolean vectors
    fInt1Type = LLVMInt1TypeInContext(fContext);
    fInt1VectorType = LLVMVectorType(fInt1Type, fVectorCount);
    fInt1Vector2Type = LLVMVectorType(fInt1Type, 2);
    fInt1Vector3Type = LLVMVectorType(fInt1Type, 3);
    fInt1Vector4Type = LLVMVectorType(fInt1Type, 4);

    // bytes and the generic i8* pointer
    fInt8Type = LLVMInt8TypeInContext(fContext);
    fInt8PtrType = LLVMPointerType(fInt8Type, 0);

    // wider integers
    fInt32Type = LLVMInt32TypeInContext(fContext);
    fInt64Type = LLVMInt64TypeInContext(fContext);
    fSizeTType = LLVMInt64TypeInContext(fContext);

    // 32-bit integer vectors
    fInt32VectorType = LLVMVectorType(fInt32Type, fVectorCount);
    fInt32Vector2Type = LLVMVectorType(fInt32Type, 2);
    fInt32Vector3Type = LLVMVectorType(fInt32Type, 3);
    fInt32Vector4Type = LLVMVectorType(fInt32Type, 4);

    // 32-bit floats and float vectors
    fFloat32Type = LLVMFloatTypeInContext(fContext);
    fFloat32VectorType = LLVMVectorType(fFloat32Type, fVectorCount);
    fFloat32Vector2Type = LLVMVectorType(fFloat32Type, 2);
    fFloat32Vector3Type = LLVMVectorType(fFloat32Type, 3);
    fFloat32Vector4Type = LLVMVectorType(fFloat32Type, 4);
}
|
|
// Releases the ORC JIT stack, if one was created, and then the LLVM context that owns every type
// created in the constructor.
JIT::~JIT() {
    // LLVMOrcDisposeInstance does not accept a null stack, so skip it when no program was ever
    // JIT-compiled. NOTE(review): this assumes fJITStack is null-initialized in the header when no
    // stack has been created — confirm.
    if (fJITStack) {
        LLVMOrcDisposeInstance(fJITStack);
    }
    LLVMContextDispose(fContext);
}
|
|
void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
|
std::vector<LLVMTypeRef> parameters) {
|
bool found = false;
|
for (const auto& pair : *fProgram->fSymbols) {
|
if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) {
|
const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second;
|
if (pair.first != ourName || returnType != this->getType(f.fReturnType) ||
|
parameters.size() != f.fParameters.size()) {
|
continue;
|
}
|
for (size_t i = 0; i < parameters.size(); ++i) {
|
if (parameters[i] != this->getType(f.fParameters[i]->fType)) {
|
goto next;
|
}
|
}
|
fFunctions[&f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(returnType,
|
parameters.data(),
|
parameters.size(),
|
false));
|
found = true;
|
}
|
if (Symbol::kUnresolvedFunction_Kind == pair.second->fKind) {
|
// FIXME consolidate this with the code above
|
for (const auto& f : ((const UnresolvedFunction&) *pair.second).fFunctions) {
|
if (pair.first != ourName || returnType != this->getType(f->fReturnType) ||
|
parameters.size() != f->fParameters.size()) {
|
continue;
|
}
|
for (size_t i = 0; i < parameters.size(); ++i) {
|
if (parameters[i] != this->getType(f->fParameters[i]->fType)) {
|
goto next;
|
}
|
}
|
fFunctions[f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(
|
returnType,
|
parameters.data(),
|
parameters.size(),
|
false));
|
found = true;
|
}
|
}
|
next:;
|
}
|
SkASSERT(found);
|
}
|
|
// Declares the C implementations of the SkSL builtins the JIT supports and binds them to the
// matching SkSL declarations. Each real function must have exactly the declared C signature.
// The calls are emitted with float arguments and float returns, so the single-precision libm
// entry points (fabsf, sinf, ...) are used rather than the double versions. The sksl_* helpers
// are defined at the top of this file.
void JIT::loadBuiltinFunctions() {
    // float abs(float): fabsf, not fabs, whose double-precision ABI would mismatch the call.
    this->addBuiltinFunction("abs", "fabsf", fFloat32Type, { fFloat32Type });
    this->addBuiltinFunction("sin", "sinf", fFloat32Type, { fFloat32Type });
    this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type });
    this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type });
    this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type });
    this->addBuiltinFunction("clamp", "sksl_clamp1", fFloat32Type, { fFloat32Type,
                                                                     fFloat32Type,
                                                                     fFloat32Type });
    this->addBuiltinFunction("clamp", "sksl_clamp2", fFloat32Vector2Type, { fFloat32Vector2Type,
                                                                            fFloat32Type,
                                                                            fFloat32Type });
    this->addBuiltinFunction("clamp", "sksl_clamp3", fFloat32Vector3Type, { fFloat32Vector3Type,
                                                                            fFloat32Type,
                                                                            fFloat32Type });
    this->addBuiltinFunction("clamp", "sksl_clamp4", fFloat32Vector4Type, { fFloat32Vector4Type,
                                                                            fFloat32Type,
                                                                            fFloat32Type });
    this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type });
}
|
|
uint64_t JIT::resolveSymbol(const char* name, JIT* jit) {
|
LLVMOrcTargetAddress result;
|
if (!LLVMOrcGetSymbolAddress(jit->fJITStack, &result, name)) {
|
if (!strcmp(name, "_sksl_pipeline_append")) {
|
result = (uint64_t) &sksl_pipeline_append;
|
} else if (!strcmp(name, "_sksl_pipeline_append_callback")) {
|
result = (uint64_t) &sksl_pipeline_append_callback;
|
} else if (!strcmp(name, "_sksl_clamp1")) {
|
result = (uint64_t) &sksl_clamp1;
|
} else if (!strcmp(name, "_sksl_clamp2")) {
|
result = (uint64_t) &sksl_clamp2;
|
} else if (!strcmp(name, "_sksl_clamp3")) {
|
result = (uint64_t) &sksl_clamp3;
|
} else if (!strcmp(name, "_sksl_clamp4")) {
|
result = (uint64_t) &sksl_clamp4;
|
} else if (!strcmp(name, "_sksl_debug_print")) {
|
result = (uint64_t) &sksl_debug_print;
|
} else {
|
result = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name);
|
}
|
}
|
SkASSERT(result);
|
return result;
|
}
|
|
// Emits a call to the LLVM function previously registered for fc's declaration in fFunctions.
// Arguments are compiled left to right.
LLVMValueRef JIT::compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc) {
    LLVMValueRef callee = fFunctions[&fc.fFunction];
    SkASSERT(callee);
    std::vector<LLVMValueRef> args;
    args.reserve(fc.fArguments.size());
    for (const auto& arg : fc.fArguments) {
        args.push_back(this->compileExpression(builder, *arg));
    }
    return LLVMBuildCall(builder, callee, args.data(), args.size(), "");
}
|
|
// Maps an SkSL type to the LLVM type that represents it in JIT-compiled code:
//   - void -> void, SkRasterPipeline -> i8*
//   - signed/unsigned scalars -> i32, floating point scalars -> float, bool -> i1
//   - arrays -> pointer to the component type
//   - 2-4 component float/half vectors -> float vectors
//   - int/short/byte/uint/ushort/ubyte vectors -> i32 vectors, bool vectors -> i1 vectors
// Aborts on any other type.
LLVMTypeRef JIT::getType(const Type& type) {
    switch (type.kind()) {
        case Type::kOther_Kind:
            if (type.name() == "void") {
                return fVoidType;
            }
            SkASSERT(type.name() == "SkRasterPipeline");
            return fInt8PtrType;
        case Type::kScalar_Kind:
            if (type.isSigned() || type.isUnsigned()) {
                return fInt32Type;
            }
            if (type.isFloat()) {
                return fFloat32Type;
            }
            SkASSERT(type.name() == "bool");
            return fInt1Type;
        case Type::kArray_Kind:
            return LLVMPointerType(this->getType(type.componentType()), 0);
        case Type::kVector_Kind:
            if (type.name() == "float2" || type.name() == "half2") {
                return fFloat32Vector2Type;
            }
            if (type.name() == "float3" || type.name() == "half3") {
                return fFloat32Vector3Type;
            }
            if (type.name() == "float4" || type.name() == "half4") {
                return fFloat32Vector4Type;
            }
            // all integer widths are represented as i32, matching the scalar case above
            if (type.name() == "int2" || type.name() == "short2" || type.name() == "byte2" ||
                type.name() == "uint2" || type.name() == "ushort2" || type.name() == "ubyte2") {
                return fInt32Vector2Type;
            }
            if (type.name() == "int3" || type.name() == "short3" || type.name() == "byte3" ||
                type.name() == "uint3" || type.name() == "ushort3" || type.name() == "ubyte3") {
                return fInt32Vector3Type;
            }
            if (type.name() == "int4" || type.name() == "short4" || type.name() == "byte4" ||
                type.name() == "uint4" || type.name() == "ushort4" || type.name() == "ubyte4") {
                return fInt32Vector4Type;
            }
            if (type.name() == "bool2") {
                return fInt1Vector2Type;
            }
            if (type.name() == "bool3") {
                return fInt1Vector3Type;
            }
            if (type.name() == "bool4") {
                return fInt1Vector4Type;
            }
            // fall through
        default:
            ABORT("unsupported type");
    }
}
|
|
// Moves the builder's insertion point to the end of block and records block in fCurrentBlock, so
// that code which needs to know the current block (phi construction, parameter promotion) stays
// in sync with the builder.
void JIT::setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block) {
    LLVMPositionBuilderAtEnd(builder, block);
    fCurrentBlock = block;
}
|
|
// Returns an LValue that can load from and store to the storage denoted by expr. The supported
// forms are:
//   - variable references
//   - ternaries whose branches are both lvalues
//   - swizzles of lvalues
// Anything else aborts.
//
// The first time a by-value parameter (no out flag) is used as an lvalue, it is promoted: an
// alloca is created in fAllocaBlock and initialized with the parameter's incoming value.
// fVariables is repointed at the alloca, and the variable is recorded in fPromotedParameters.
std::unique_ptr<JIT::LValue> JIT::getLValue(LLVMBuilderRef builder, const Expression& expr) {
    switch (expr.fKind) {
        case Expression::kVariableReference_Kind: {
            // An lvalue backed directly by a pointer to the variable's storage.
            class PointerLValue : public LValue {
            public:
                explicit PointerLValue(LLVMValueRef ptr)
                : fPointer(ptr) {}

                LLVMValueRef load(LLVMBuilderRef builder) override {
                    return LLVMBuildLoad(builder, fPointer, "lvalue load");
                }

                void store(LLVMBuilderRef builder, LLVMValueRef value) override {
                    LLVMBuildStore(builder, value, fPointer);
                }

            private:
                LLVMValueRef fPointer;
            };
            const Variable* var = &((const VariableReference&) expr).fVariable;
            if (var->fStorage == Variable::kParameter_Storage &&
                !(var->fModifiers.fFlags & Modifiers::kOut_Flag) &&
                fPromotedParameters.find(var) == fPromotedParameters.end()) {
                // promote parameter to variable
                fPromotedParameters.insert(var);
                LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
                LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(var->fType),
                                                      String(var->fName).c_str());
                LLVMBuildStore(builder, fVariables[var], alloca);
                LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
                fVariables[var] = alloca;
            }
            LLVMValueRef ptr = fVariables[var];
            return std::unique_ptr<LValue>(new PointerLValue(ptr));
        }
        case Expression::kTernary_Kind: {
            // An lvalue that selects between two lvalues at runtime based on fTest.
            class TernaryLValue : public LValue {
            public:
                TernaryLValue(JIT* jit, LLVMValueRef test, std::unique_ptr<LValue> ifTrue,
                              std::unique_ptr<LValue> ifFalse)
                : fJIT(*jit)
                , fTest(test)
                , fIfTrue(std::move(ifTrue))
                , fIfFalse(std::move(ifFalse)) {}

                LLVMValueRef load(LLVMBuilderRef builder) override {
                    LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
                                                                        fJIT.fContext,
                                                                        fJIT.fCurrentFunction,
                                                                        "true ? ...");
                    LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
                                                                        fJIT.fContext,
                                                                        fJIT.fCurrentFunction,
                                                                        "false ? ...");
                    LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
                                                                            fJIT.fCurrentFunction,
                                                                            "ternary merge");
                    LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
                    fJIT.setBlock(builder, trueBlock);
                    LLVMValueRef ifTrue = fIfTrue->load(builder);
                    // Loading may have emitted new blocks (e.g. a nested ternary), so the phi's
                    // incoming edge comes from wherever the builder ended up.
                    LLVMBasicBlockRef trueEnd = fJIT.fCurrentBlock;
                    LLVMBuildBr(builder, merge);
                    fJIT.setBlock(builder, falseBlock);
                    LLVMValueRef ifFalse = fIfFalse->load(builder);
                    LLVMBasicBlockRef falseEnd = fJIT.fCurrentBlock;
                    LLVMBuildBr(builder, merge);
                    fJIT.setBlock(builder, merge);
                    // the phi merges loaded values, so it has the value type, not a pointer type
                    LLVMValueRef phi = LLVMBuildPhi(builder, LLVMTypeOf(ifTrue), "?");
                    LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
                    LLVMBasicBlockRef incomingBlocks[2] = { trueEnd, falseEnd };
                    LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
                    return phi;
                }

                void store(LLVMBuilderRef builder, LLVMValueRef value) override {
                    LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
                                                                        fJIT.fContext,
                                                                        fJIT.fCurrentFunction,
                                                                        "true ? ...");
                    LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
                                                                        fJIT.fContext,
                                                                        fJIT.fCurrentFunction,
                                                                        "false ? ...");
                    LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
                                                                            fJIT.fCurrentFunction,
                                                                            "ternary merge");
                    LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
                    fJIT.setBlock(builder, trueBlock);
                    fIfTrue->store(builder, value);
                    LLVMBuildBr(builder, merge);
                    fJIT.setBlock(builder, falseBlock);
                    fIfFalse->store(builder, value);
                    LLVMBuildBr(builder, merge);
                    fJIT.setBlock(builder, merge);
                }

            private:
                JIT& fJIT;
                LLVMValueRef fTest;
                std::unique_ptr<LValue> fIfTrue;
                std::unique_ptr<LValue> fIfFalse;
            };
            const TernaryExpression& t = (const TernaryExpression&) expr;
            LLVMValueRef test = this->compileExpression(builder, *t.fTest);
            // Build the branch lvalues in a fixed order; as constructor arguments their evaluation
            // order (and thus the order of any IR they emit) would be unspecified.
            std::unique_ptr<LValue> ifTrue = this->getLValue(builder, *t.fIfTrue);
            std::unique_ptr<LValue> ifFalse = this->getLValue(builder, *t.fIfFalse);
            return std::unique_ptr<LValue>(new TernaryLValue(this,
                                                             test,
                                                             std::move(ifTrue),
                                                             std::move(ifFalse)));
        }
        case Expression::kSwizzle_Kind: {
            // An lvalue that reads/writes selected components of a vector lvalue. Stores are
            // read-modify-write: the whole base is loaded, components replaced, and stored back.
            class SwizzleLValue : public LValue {
            public:
                SwizzleLValue(JIT* jit, LLVMTypeRef type, std::unique_ptr<LValue> base,
                              std::vector<int> components)
                : fJIT(*jit)
                , fType(type)
                , fBase(std::move(base))
                , fComponents(std::move(components)) {}

                LLVMValueRef load(LLVMBuilderRef builder) override {
                    LLVMValueRef base = fBase->load(builder);
                    if (fComponents.size() > 1) {
                        LLVMValueRef result = LLVMGetUndef(fType);
                        for (size_t i = 0; i < fComponents.size(); ++i) {
                            LLVMValueRef element = LLVMBuildExtractElement(
                                                                       builder,
                                                                       base,
                                                                       LLVMConstInt(
                                                                                 fJIT.fInt32Type,
                                                                                 fComponents[i],
                                                                                 false),
                                                                       "swizzle extract");
                            result = LLVMBuildInsertElement(builder, result, element,
                                                            LLVMConstInt(fJIT.fInt32Type, i,
                                                                         false),
                                                            "swizzle insert");
                        }
                        return result;
                    }
                    SkASSERT(fComponents.size() == 1);
                    return LLVMBuildExtractElement(builder, base,
                                                   LLVMConstInt(fJIT.fInt32Type,
                                                                fComponents[0],
                                                                false),
                                                   "swizzle extract");
                }

                void store(LLVMBuilderRef builder, LLVMValueRef value) override {
                    LLVMValueRef result = fBase->load(builder);
                    if (fComponents.size() > 1) {
                        for (size_t i = 0; i < fComponents.size(); ++i) {
                            LLVMValueRef element = LLVMBuildExtractElement(builder, value,
                                                                           LLVMConstInt(
                                                                                 fJIT.fInt32Type,
                                                                                 i,
                                                                                 false),
                                                                           "swizzle extract");
                            result = LLVMBuildInsertElement(builder, result, element,
                                                            LLVMConstInt(fJIT.fInt32Type,
                                                                         fComponents[i],
                                                                         false),
                                                            "swizzle insert");
                        }
                    } else {
                        result = LLVMBuildInsertElement(builder, result, value,
                                                        LLVMConstInt(fJIT.fInt32Type,
                                                                     fComponents[0],
                                                                     false),
                                                        "swizzle insert");
                    }
                    fBase->store(builder, result);
                }

            private:
                JIT& fJIT;
                LLVMTypeRef fType;
                std::unique_ptr<LValue> fBase;
                std::vector<int> fComponents;
            };
            const Swizzle& s = (const Swizzle&) expr;
            return std::unique_ptr<LValue>(new SwizzleLValue(this, this->getType(s.fType),
                                                             this->getLValue(builder, *s.fBase),
                                                             s.fComponents));
        }
        default:
            ABORT("unsupported lvalue");
    }
}
|
|
// Classifies a scalar type, or a vector by its component type, as signed int, unsigned int or
// float for the purpose of choosing LLVM instructions. Aborts on anything else (including bool).
JIT::TypeKind JIT::typeKind(const Type& type) {
    if (Type::kVector_Kind == type.kind()) {
        return this->typeKind(type.componentType());
    }
    const auto& name = type.fName;
    const bool isSignedInt = name == "int" || name == "short" || name == "byte";
    if (isSignedInt) {
        return JIT::kInt_TypeKind;
    }
    const bool isUnsignedInt = name == "uint" || name == "ushort" || name == "ubyte";
    if (isUnsignedInt) {
        return JIT::kUInt_TypeKind;
    }
    const bool isFloat = name == "float" || name == "double" || name == "half";
    if (isFloat) {
        return JIT::kFloat_TypeKind;
    }
    ABORT("unsupported type: %s\n", type.description().c_str());
}
|
|
// Replaces *value (a scalar) with a `columns`-wide vector whose every lane holds that scalar.
void JIT::vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns) {
    LLVMValueRef scalar = *value;
    LLVMValueRef splat = LLVMGetUndef(LLVMVectorType(LLVMTypeOf(scalar), columns));
    for (int lane = 0; lane < columns; ++lane) {
        LLVMValueRef laneIndex = LLVMConstInt(fInt32Type, lane, false);
        splat = LLVMBuildInsertElement(builder, splat, scalar, laneIndex, "vectorize");
    }
    *value = splat;
}
|
|
// For a binary expression mixing a scalar and a vector operand, splats the scalar operand to the
// vector's width in place. Operands that are both scalars or both vectors are left untouched.
void JIT::vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left,
                    LLVMValueRef* right) {
    const Type& leftType = b.fLeft->fType;
    const Type& rightType = b.fRight->fType;
    if (Type::kScalar_Kind == leftType.kind() && Type::kVector_Kind == rightType.kind()) {
        this->vectorize(builder, left, rightType.columns());
        return;
    }
    if (Type::kVector_Kind == leftType.kind() && Type::kScalar_Kind == rightType.kind()) {
        this->vectorize(builder, right, leftType.columns());
    }
}
|
|
|
// Compiles a binary expression and returns its value.
//   - Arithmetic, bitwise and shift operators dispatch on the (component) kind of the left
//     operand to the signed, unsigned or floating-point LLVM instruction.
//   - Compound assignments load the left lvalue, combine, store back, and return the new value.
//   - Scalar comparisons yield i1; vector == and != are reduced to a single i1 via the fold-and /
//     fold-or helper functions.
//   - && and || short-circuit through extra basic blocks joined by a phi.
// Mixed scalar/vector operands are splatted to the vector width first.
LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& b) {
    // BINARY: evaluate both operands and return SFunc/UFunc/FFunc applied to them.
    #define BINARY(SFunc, UFunc, FFunc) {                                    \
        LLVMValueRef left = this->compileExpression(builder, *b.fLeft);     \
        LLVMValueRef right = this->compileExpression(builder, *b.fRight);   \
        this->vectorize(builder, b, &left, &right);                         \
        switch (this->typeKind(b.fLeft->fType)) {                           \
            case kInt_TypeKind:                                             \
                return SFunc(builder, left, right, "binary");               \
            case kUInt_TypeKind:                                            \
                return UFunc(builder, left, right, "binary");               \
            case kFloat_TypeKind:                                           \
                return FFunc(builder, left, right, "binary");               \
            default:                                                        \
                ABORT("unsupported typeKind");                              \
        }                                                                   \
    }
    // COMPOUND: like BINARY, but the left operand is an lvalue that receives the result.
    #define COMPOUND(SFunc, UFunc, FFunc) {                                  \
        std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft); \
        LLVMValueRef left = lvalue->load(builder);                          \
        LLVMValueRef right = this->compileExpression(builder, *b.fRight);   \
        this->vectorize(builder, b, &left, &right);                         \
        LLVMValueRef result;                                                \
        switch (this->typeKind(b.fLeft->fType)) {                           \
            case kInt_TypeKind:                                             \
                result = SFunc(builder, left, right, "binary");             \
                break;                                                      \
            case kUInt_TypeKind:                                            \
                result = UFunc(builder, left, right, "binary");             \
                break;                                                      \
            case kFloat_TypeKind:                                           \
                result = FFunc(builder, left, right, "binary");             \
                break;                                                      \
            default:                                                        \
                ABORT("unsupported typeKind");                              \
        }                                                                   \
        lvalue->store(builder, result);                                     \
        return result;                                                      \
    }
    // COMPARE: like BINARY, for ICmp/FCmp builders that take a predicate.
    #define COMPARE(SFunc, SOp, UFunc, UOp, FFunc, FOp) {                    \
        LLVMValueRef left = this->compileExpression(builder, *b.fLeft);     \
        LLVMValueRef right = this->compileExpression(builder, *b.fRight);   \
        this->vectorize(builder, b, &left, &right);                         \
        switch (this->typeKind(b.fLeft->fType)) {                           \
            case kInt_TypeKind:                                             \
                return SFunc(builder, SOp, left, right, "binary");          \
            case kUInt_TypeKind:                                            \
                return UFunc(builder, UOp, left, right, "binary");          \
            case kFloat_TypeKind:                                           \
                return FFunc(builder, FOp, left, right, "binary");          \
            default:                                                        \
                ABORT("unsupported typeKind");                              \
        }                                                                   \
    }
    switch (b.fOperator) {
        case Token::EQ: {
            std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft);
            LLVMValueRef result = this->compileExpression(builder, *b.fRight);
            lvalue->store(builder, result);
            return result;
        }
        case Token::PLUS:
            BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
        case Token::MINUS:
            BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
        case Token::STAR:
            BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
        case Token::SLASH:
            BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
        case Token::PERCENT:
            // floating-point remainder needs frem; srem is only valid on integers
            BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildFRem);
        case Token::BITWISEAND:
            BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
        case Token::BITWISEOR:
            BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
        case Token::BITWISEXOR:
            BINARY(LLVMBuildXor, LLVMBuildXor, LLVMBuildXor);
        case Token::SHL:
            BINARY(LLVMBuildShl, LLVMBuildShl, LLVMBuildShl);
        case Token::SHR:
            BINARY(LLVMBuildAShr, LLVMBuildLShr, LLVMBuildAShr);
        case Token::PLUSEQ:
            COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
        case Token::MINUSEQ:
            COMPOUND(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
        case Token::STAREQ:
            COMPOUND(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
        case Token::SLASHEQ:
            COMPOUND(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
        case Token::PERCENTEQ:
            COMPOUND(LLVMBuildSRem, LLVMBuildURem, LLVMBuildFRem);
        case Token::BITWISEANDEQ:
            COMPOUND(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
        case Token::BITWISEOREQ:
            COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
        case Token::BITWISEXOREQ:
            COMPOUND(LLVMBuildXor, LLVMBuildXor, LLVMBuildXor);
        case Token::SHLEQ:
            COMPOUND(LLVMBuildShl, LLVMBuildShl, LLVMBuildShl);
        case Token::SHREQ:
            COMPOUND(LLVMBuildAShr, LLVMBuildLShr, LLVMBuildAShr);
        case Token::EQEQ:
            switch (b.fLeft->fType.kind()) {
                case Type::kScalar_Kind:
                    COMPARE(LLVMBuildICmp, LLVMIntEQ,
                            LLVMBuildICmp, LLVMIntEQ,
                            LLVMBuildFCmp, LLVMRealOEQ);
                case Type::kVector_Kind: {
                    LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
                    LLVMValueRef right = this->compileExpression(builder, *b.fRight);
                    this->vectorize(builder, b, &left, &right);
                    LLVMValueRef value;
                    switch (this->typeKind(b.fLeft->fType)) {
                        case kInt_TypeKind:
                            value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
                            break;
                        case kUInt_TypeKind:
                            value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
                            break;
                        case kFloat_TypeKind:
                            value = LLVMBuildFCmp(builder, LLVMRealOEQ, left, right, "binary");
                            break;
                        default:
                            ABORT("unsupported typeKind");
                    }
                    // reduce the per-lane results to a single bool: equal iff all lanes equal
                    LLVMValueRef args[1] = { value };
                    LLVMValueRef func;
                    switch (b.fLeft->fType.columns()) {
                        case 2: func = fFoldAnd2Func; break;
                        case 3: func = fFoldAnd3Func; break;
                        case 4: func = fFoldAnd4Func; break;
                        default:
                            SkASSERT(false);
                            func = fFoldAnd2Func;
                    }
                    return LLVMBuildCall(builder, func, args, 1, "all");
                }
                default:
                    // must not fall through into the != case below
                    ABORT("unsupported comparison");
            }
        case Token::NEQ:
            switch (b.fLeft->fType.kind()) {
                case Type::kScalar_Kind:
                    COMPARE(LLVMBuildICmp, LLVMIntNE,
                            LLVMBuildICmp, LLVMIntNE,
                            LLVMBuildFCmp, LLVMRealONE);
                case Type::kVector_Kind: {
                    LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
                    LLVMValueRef right = this->compileExpression(builder, *b.fRight);
                    this->vectorize(builder, b, &left, &right);
                    LLVMValueRef value;
                    switch (this->typeKind(b.fLeft->fType)) {
                        case kInt_TypeKind:
                            value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
                            break;
                        case kUInt_TypeKind:
                            value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
                            break;
                        case kFloat_TypeKind:
                            value = LLVMBuildFCmp(builder, LLVMRealONE, left, right, "binary");
                            break;
                        default:
                            ABORT("unsupported typeKind");
                    }
                    // reduce the per-lane results to a single bool: unequal iff any lane differs
                    LLVMValueRef args[1] = { value };
                    LLVMValueRef func;
                    switch (b.fLeft->fType.columns()) {
                        case 2: func = fFoldOr2Func; break;
                        case 3: func = fFoldOr3Func; break;
                        case 4: func = fFoldOr4Func; break;
                        default:
                            SkASSERT(false);
                            func = fFoldOr2Func;
                    }
                    return LLVMBuildCall(builder, func, args, 1, "all");
                }
                default:
                    // must not fall through into the < case below
                    ABORT("unsupported comparison");
            }
        case Token::LT:
            COMPARE(LLVMBuildICmp, LLVMIntSLT,
                    LLVMBuildICmp, LLVMIntULT,
                    LLVMBuildFCmp, LLVMRealOLT);
        case Token::LTEQ:
            COMPARE(LLVMBuildICmp, LLVMIntSLE,
                    LLVMBuildICmp, LLVMIntULE,
                    LLVMBuildFCmp, LLVMRealOLE);
        case Token::GT:
            COMPARE(LLVMBuildICmp, LLVMIntSGT,
                    LLVMBuildICmp, LLVMIntUGT,
                    LLVMBuildFCmp, LLVMRealOGT);
        case Token::GTEQ:
            COMPARE(LLVMBuildICmp, LLVMIntSGE,
                    LLVMBuildICmp, LLVMIntUGE,
                    LLVMBuildFCmp, LLVMRealOGE);
        case Token::LOGICALAND: {
            LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
            LLVMBasicBlockRef ifFalse = fCurrentBlock;
            LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                     "true && ...");
            LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                    "&& merge");
            LLVMBuildCondBr(builder, left, ifTrue, merge);
            this->setBlock(builder, ifTrue);
            LLVMValueRef right = this->compileExpression(builder, *b.fRight);
            // Compiling the right side may have moved us to another block (e.g. a nested && or
            // ||); the edge into merge comes from the current block, not necessarily ifTrue.
            LLVMBasicBlockRef rightEnd = fCurrentBlock;
            LLVMBuildBr(builder, merge);
            this->setBlock(builder, merge);
            LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "&&");
            LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 0, false) };
            LLVMBasicBlockRef incomingBlocks[2] = { rightEnd, ifFalse };
            LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
            return phi;
        }
        case Token::LOGICALOR: {
            LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
            LLVMBasicBlockRef ifTrue = fCurrentBlock;
            LLVMBasicBlockRef ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                      "false || ...");
            LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                    "|| merge");
            LLVMBuildCondBr(builder, left, merge, ifFalse);
            this->setBlock(builder, ifFalse);
            LLVMValueRef right = this->compileExpression(builder, *b.fRight);
            // see LOGICALAND: the right side may end in a different block than ifFalse
            LLVMBasicBlockRef rightEnd = fCurrentBlock;
            LLVMBuildBr(builder, merge);
            this->setBlock(builder, merge);
            LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "||");
            LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 1, false) };
            LLVMBasicBlockRef incomingBlocks[2] = { rightEnd, ifTrue };
            LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
            return phi;
        }
        default:
            printf("%s\n", b.description().c_str());
            ABORT("unsupported binary operator");
    }
}
|
|
// Compiles base[index]: the base is treated as a pointer (arrays lower to pointers in getType),
// offset by index with a GEP, and the element is loaded. The base is compiled before the index.
LLVMValueRef JIT::compileIndex(LLVMBuilderRef builder, const IndexExpression& idx) {
    LLVMValueRef base = this->compileExpression(builder, *idx.fBase);
    LLVMValueRef indices[1] = { this->compileExpression(builder, *idx.fIndex) };
    LLVMValueRef elementPtr = LLVMBuildGEP(builder, base, indices, 1, "index ptr");
    return LLVMBuildLoad(builder, elementPtr, "index load");
}
|
|
// Compiles `operand++` / `operand--`. The operand is loaded through its lvalue and the
// incremented/decremented value is stored back. The value loaded before the update is returned.
// The constant 1 is built per kind: LLVMConstInt only accepts integer types, so floats use
// LLVMConstReal.
LLVMValueRef JIT::compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p) {
    std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
    LLVMValueRef result = lvalue->load(builder);
    LLVMTypeRef type = this->getType(p.fType);
    LLVMValueRef mod;
    switch (p.fOperator) {
        case Token::PLUSPLUS:
            switch (this->typeKind(p.fType)) {
                case kInt_TypeKind: // fall through
                case kUInt_TypeKind:
                    mod = LLVMBuildAdd(builder, result, LLVMConstInt(type, 1, false), "++");
                    break;
                case kFloat_TypeKind:
                    mod = LLVMBuildFAdd(builder, result, LLVMConstReal(type, 1.0), "++");
                    break;
                default:
                    ABORT("unsupported typeKind");
            }
            break;
        case Token::MINUSMINUS:
            switch (this->typeKind(p.fType)) {
                case kInt_TypeKind: // fall through
                case kUInt_TypeKind:
                    mod = LLVMBuildSub(builder, result, LLVMConstInt(type, 1, false), "--");
                    break;
                case kFloat_TypeKind:
                    mod = LLVMBuildFSub(builder, result, LLVMConstReal(type, 1.0), "--");
                    break;
                default:
                    ABORT("unsupported typeKind");
            }
            break;
        default:
            ABORT("unsupported postfix op");
    }
    lvalue->store(builder, mod);
    return result;
}
|
|
// Compiles a prefix expression:
//   - `!x`: xor of the bool operand with 1
//   - `-x`: integer or floating-point negation
//   - `++x` / `--x`: load through the lvalue, update, store back, and return the updated value
// Constants are built per kind, because LLVMConstInt only accepts integer types; the original
// unconditional LLVMConstInt on the operand type was invalid for floats.
LLVMValueRef JIT::compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p) {
    LLVMTypeRef type = this->getType(p.fType);
    if (Token::LOGICALNOT == p.fOperator) {
        LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
        return LLVMBuildXor(builder, base, LLVMConstInt(type, 1, false), "!");
    }
    if (Token::MINUS == p.fOperator) {
        LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
        if (kFloat_TypeKind == this->typeKind(p.fType)) {
            return LLVMBuildFNeg(builder, base, "-");
        }
        return LLVMBuildNeg(builder, base, "-");
    }
    std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
    LLVMValueRef raw = lvalue->load(builder);
    LLVMValueRef result;
    switch (p.fOperator) {
        case Token::PLUSPLUS:
            switch (this->typeKind(p.fType)) {
                case kInt_TypeKind: // fall through
                case kUInt_TypeKind:
                    result = LLVMBuildAdd(builder, raw, LLVMConstInt(type, 1, false), "++");
                    break;
                case kFloat_TypeKind:
                    result = LLVMBuildFAdd(builder, raw, LLVMConstReal(type, 1.0), "++");
                    break;
                default:
                    ABORT("unsupported typeKind");
            }
            break;
        case Token::MINUSMINUS:
            switch (this->typeKind(p.fType)) {
                case kInt_TypeKind: // fall through
                case kUInt_TypeKind:
                    result = LLVMBuildSub(builder, raw, LLVMConstInt(type, 1, false), "--");
                    break;
                case kFloat_TypeKind:
                    result = LLVMBuildFSub(builder, raw, LLVMConstReal(type, 1.0), "--");
                    break;
                default:
                    ABORT("unsupported typeKind");
            }
            break;
        default:
            ABORT("unsupported prefix op");
    }
    lvalue->store(builder, result);
    return result;
}
|
|
// Returns the current value of a variable. By-value parameters that have not been promoted to an
// alloca (see getLValue) are stored in fVariables as the value itself and are returned directly.
// Every other variable is stored in fVariables as a pointer and is loaded.
LLVMValueRef JIT::compileVariableReference(LLVMBuilderRef builder, const VariableReference& v) {
    const Variable& var = v.fVariable;
    const bool isParameter = Variable::kParameter_Storage == var.fStorage;
    const bool isOut = (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
    const bool isPromoted = fPromotedParameters.find(&var) != fPromotedParameters.end();
    if (isParameter && !isOut && !isPromoted) {
        return fVariables[&var];
    }
    return LLVMBuildLoad(builder, fVariables[&var], String(var.fName).c_str());
}
|
|
// Emits a call that appends a stage to the SkRasterPipeline given as a.fArguments[0].
//   - For the callback stage: the referenced SkSL function's definition is located in fProgram,
//     compiled with compileStageFunction, and passed to fAppendCallbackFunc as an i8*.
//   - For other stages: fAppendFunc receives the stage number and an i8* context, taken from the
//     optional second argument or null.
void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) {
    SkASSERT(a.fArguments.size() >= 1);
    SkASSERT(a.fArguments[0]->fType == *fCompiler.context().fSkRasterPipeline_Type);
    LLVMValueRef pipeline = this->compileExpression(builder, *a.fArguments[0]);
    LLVMValueRef stage = LLVMConstInt(fInt32Type, a.fStage, 0);
    if (SkRasterPipeline::callback == a.fStage) {
        SkASSERT(a.fArguments.size() == 2);
        SkASSERT(a.fArguments[1]->fKind == Expression::kFunctionReference_Kind);
        const FunctionDeclaration& target =
                *((FunctionReference&) *a.fArguments[1]).fFunctions[0];
        bool found = false;
        for (const auto& element : *fProgram) {
            if (ProgramElement::kFunction_Kind != element.fKind) {
                continue;
            }
            const FunctionDefinition& def = (const FunctionDefinition&) element;
            if (&def.fDeclaration != &target) {
                continue;
            }
            LLVMValueRef fn = this->compileStageFunction(def);
            LLVMValueRef callbackArgs[2] = {
                pipeline,
                LLVMBuildBitCast(builder, fn, fInt8PtrType, "callback cast")
            };
            LLVMBuildCall(builder, fAppendCallbackFunc, callbackArgs, 2, "");
            found = true;
            break;
        }
        SkASSERT(found);
        return;
    }
    LLVMValueRef ctx;
    if (a.fArguments.size() != 2) {
        SkASSERT(a.fArguments.size() == 1);
        ctx = LLVMConstNull(fInt8PtrType);
    } else {
        LLVMValueRef rawCtx = this->compileExpression(builder, *a.fArguments[1]);
        ctx = LLVMBuildBitCast(builder, rawCtx, fInt8PtrType, "context cast");
    }
    LLVMValueRef appendArgs[3] = { pipeline, stage, ctx };
    LLVMBuildCall(builder, fAppendFunc, appendArgs, 3, "");
}
|
|
// Compiles a constructor expression.
//   - Scalar constructors convert between int, uint and float. Int and uint share the i32
//     representation, so conversions between them are no-ops.
//   - Vector constructors either splat a single scalar argument or concatenate the lanes of
//     their (scalar or vector) arguments in order.
// Every unsupported combination reaches the ABORT at the end, rather than falling through into
// an unrelated case.
LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor& c) {
    switch (c.fType.kind()) {
        case Type::kScalar_Kind: {
            SkASSERT(c.fArguments.size() == 1);
            TypeKind from = this->typeKind(c.fArguments[0]->fType);
            TypeKind to = this->typeKind(c.fType);
            LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]);
            switch (to) {
                case kFloat_TypeKind:
                    switch (from) {
                        case kInt_TypeKind:
                            return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast");
                        case kUInt_TypeKind:
                            return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast");
                        case kFloat_TypeKind:
                            return base;
                        case kBool_TypeKind:
                            break;
                    }
                    break;
                case kInt_TypeKind:
                    switch (from) {
                        case kInt_TypeKind: // fall through
                        case kUInt_TypeKind:
                            return base;
                        case kFloat_TypeKind:
                            return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast");
                        case kBool_TypeKind:
                            break;
                    }
                    break;
                case kUInt_TypeKind:
                    switch (from) {
                        case kInt_TypeKind: // fall through
                        case kUInt_TypeKind:
                            return base;
                        case kFloat_TypeKind:
                            return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast");
                        case kBool_TypeKind:
                            break;
                    }
                    break;
                case kBool_TypeKind:
                    break;
            }
            // unsupported scalar conversion; must not fall into the vector case
            break;
        }
        case Type::kVector_Kind: {
            LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType));
            if (c.fArguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
                // splat the single scalar argument into every lane
                LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]);
                for (int i = 0; i < c.fType.columns(); ++i) {
                    vec = LLVMBuildInsertElement(builder, vec, value,
                                                 LLVMConstInt(fInt32Type, i, false),
                                                 "vec build 1");
                }
            } else {
                // concatenate the lanes of all arguments; index is the next lane to fill
                int index = 0;
                for (const auto& arg : c.fArguments) {
                    LLVMValueRef value = this->compileExpression(builder, *arg);
                    if (arg->fType.kind() == Type::kVector_Kind) {
                        for (int i = 0; i < arg->fType.columns(); ++i) {
                            // read column i of the argument (not of the vector being built)
                            LLVMValueRef column = LLVMBuildExtractElement(builder,
                                                                          value,
                                                                          LLVMConstInt(fInt32Type,
                                                                                       i,
                                                                                       false),
                                                                          "construct extract");
                            vec = LLVMBuildInsertElement(builder, vec, column,
                                                         LLVMConstInt(fInt32Type, index++, false),
                                                         "vec build 2");
                        }
                    } else {
                        vec = LLVMBuildInsertElement(builder, vec, value,
                                                     LLVMConstInt(fInt32Type, index++, false),
                                                     "vec build 3");
                    }
                }
            }
            return vec;
        }
        default:
            break;
    }
    ABORT("unsupported constructor");
}
|
|
/**
 * Compiles a swizzle in the scalar (per-pixel) code path.
 *
 * A single-component swizzle yields the selected element of the base vector. A multi-component
 * swizzle builds a new vector whose element i is base[fComponents[i]].
 */
LLVMValueRef JIT::compileSwizzle(LLVMBuilderRef builder, const Swizzle& s) {
    LLVMValueRef source = this->compileExpression(builder, *s.fBase);
    size_t count = s.fComponents.size();
    if (count <= 1) {
        SkASSERT(count == 1);
        LLVMValueRef lane = LLVMConstInt(fInt32Type, s.fComponents[0], false);
        return LLVMBuildExtractElement(builder, source, lane, "swizzle extract");
    }
    LLVMValueRef swizzled = LLVMGetUndef(this->getType(s.fType));
    for (size_t dst = 0; dst < count; ++dst) {
        LLVMValueRef srcLane = LLVMConstInt(fInt32Type, s.fComponents[dst], false);
        LLVMValueRef picked = LLVMBuildExtractElement(builder, source, srcLane, "swizzle extract");
        LLVMValueRef dstLane = LLVMConstInt(fInt32Type, dst, false);
        swizzled = LLVMBuildInsertElement(builder, swizzled, picked, dstLane, "swizzle insert");
    }
    return swizzled;
}
|
|
/**
 * Compiles `test ? ifTrue : ifFalse` with real control flow.
 *
 * Each operand is evaluated in its own basic block, so only the selected side runs. The two
 * results are joined with a phi node in the merge block. Blocks are appended in the order
 * "if true", "if merge", "if false".
 */
LLVMValueRef JIT::compileTernary(LLVMBuilderRef builder, const TernaryExpression& t) {
    LLVMValueRef condition = this->compileExpression(builder, *t.fTest);
    LLVMBasicBlockRef whenTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                               "if true");
    LLVMBasicBlockRef joined = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                             "if merge");
    LLVMBasicBlockRef whenFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                "if false");
    LLVMBuildCondBr(builder, condition, whenTrue, whenFalse);

    LLVMValueRef values[2];
    LLVMBasicBlockRef predecessors[2];

    // An operand may itself contain control flow. The phi's incoming edge is therefore the block
    // compilation ended in (fCurrentBlock), not necessarily the block it started in.
    this->setBlock(builder, whenTrue);
    values[0] = this->compileExpression(builder, *t.fIfTrue);
    predecessors[0] = fCurrentBlock;
    LLVMBuildBr(builder, joined);

    this->setBlock(builder, whenFalse);
    values[1] = this->compileExpression(builder, *t.fIfFalse);
    predecessors[1] = fCurrentBlock;
    LLVMBuildBr(builder, joined);

    this->setBlock(builder, joined);
    LLVMValueRef result = LLVMBuildPhi(builder, this->getType(t.fType), "?");
    LLVMAddIncoming(result, values, predecessors, 2);
    return result;
}
|
|
/**
 * Compiles an expression in the scalar (per-pixel) code path and returns its value.
 *
 * Append-stage expressions produce no value and return a null LLVMValueRef. Expression kinds
 * this backend does not support (field access, settings, type references, and anything not
 * listed) abort with a message naming the offending expression. Previously the bare abort()
 * calls made that diagnostic unreachable.
 */
LLVMValueRef JIT::compileExpression(LLVMBuilderRef builder, const Expression& expr) {
    switch (expr.fKind) {
        case Expression::kAppendStage_Kind:
            this->appendStage(builder, (const AppendStage&) expr);
            return LLVMValueRef();
        case Expression::kBinary_Kind:
            return this->compileBinary(builder, (BinaryExpression&) expr);
        case Expression::kBoolLiteral_Kind:
            return LLVMConstInt(fInt1Type, ((BoolLiteral&) expr).fValue, false);
        case Expression::kConstructor_Kind:
            return this->compileConstructor(builder, (Constructor&) expr);
        case Expression::kIntLiteral_Kind:
            return LLVMConstInt(this->getType(expr.fType), ((IntLiteral&) expr).fValue, true);
        case Expression::kFloatLiteral_Kind:
            return LLVMConstReal(this->getType(expr.fType), ((FloatLiteral&) expr).fValue);
        case Expression::kFunctionCall_Kind:
            return this->compileFunctionCall(builder, (FunctionCall&) expr);
        case Expression::kIndex_Kind:
            return this->compileIndex(builder, (IndexExpression&) expr);
        case Expression::kPrefix_Kind:
            return this->compilePrefix(builder, (PrefixExpression&) expr);
        case Expression::kPostfix_Kind:
            return this->compilePostfix(builder, (PostfixExpression&) expr);
        case Expression::kSwizzle_Kind:
            return this->compileSwizzle(builder, (Swizzle&) expr);
        case Expression::kVariableReference_Kind:
            return this->compileVariableReference(builder, (VariableReference&) expr);
        case Expression::kTernary_Kind:
            return this->compileTernary(builder, (TernaryExpression&) expr);
        case Expression::kFieldAccess_Kind:
        case Expression::kSetting_Kind:
        case Expression::kTypeReference_Kind:
        default:
            // unsupported: fall out of the switch to the diagnostic below
            break;
    }
    ABORT("unsupported expression: %s\n", expr.description().c_str());
}
|
|
/** Compiles every statement of `block`, in source order, at the current insertion point. */
void JIT::compileBlock(LLVMBuilderRef builder, const Block& block) {
    size_t count = block.fStatements.size();
    for (size_t index = 0; index < count; ++index) {
        this->compileStatement(builder, *block.fStatements[index]);
    }
}
|
|
/**
 * Compiles a local variable declaration statement.
 *
 * Each declared variable gets a stack slot (alloca) emitted into fAllocaBlock and recorded in
 * fVariables. If the variable has an initializer, it is evaluated and stored into the slot back
 * in the current block.
 */
void JIT::compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls) {
    for (const auto& entry : decls.fDeclaration->fVars) {
        const VarDeclaration& varDecl = (VarDeclaration&) *entry;

        // Hop to the alloca block to create the slot, then return to where we were.
        LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
        LLVMValueRef slot = LLVMBuildAlloca(builder, this->getType(varDecl.fVar->fType),
                                            String(varDecl.fVar->fName).c_str());
        fVariables[varDecl.fVar] = slot;
        LLVMPositionBuilderAtEnd(builder, fCurrentBlock);

        if (!varDecl.fValue) {
            continue;
        }
        LLVMValueRef initial = this->compileExpression(builder, *varDecl.fValue);
        LLVMBuildStore(builder, initial, slot);
    }
}
|
|
/**
 * Compiles an if statement.
 *
 * The condition branches to "if true", and otherwise to "if false" when there is an else
 * clause, or straight to "if merge" when there is not. Each arm jumps to the merge block unless
 * it already ends in a branch of its own. Compilation continues in the merge block.
 */
void JIT::compileIf(LLVMBuilderRef builder, const IfStatement& i) {
    LLVMValueRef condition = this->compileExpression(builder, *i.fTest);
    LLVMBasicBlockRef thenBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                "if true");
    LLVMBasicBlockRef mergeBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                 "if merge");
    LLVMBasicBlockRef elseBlock = i.fIfFalse
            ? LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if false")
            : mergeBlock;
    LLVMBuildCondBr(builder, condition, thenBlock, elseBlock);

    // Emits one arm into `block` and, unless it already branched away, falls into the merge.
    auto emitArm = [&](LLVMBasicBlockRef block, const Statement& arm) {
        this->setBlock(builder, block);
        this->compileStatement(builder, arm);
        if (!ends_with_branch(arm)) {
            LLVMBuildBr(builder, mergeBlock);
        }
    };
    emitArm(thenBlock, *i.fIfTrue);
    if (i.fIfFalse) {
        emitArm(elseBlock, *i.fIfFalse);
    }
    this->setBlock(builder, mergeBlock);
}
|
|
/**
 * Compiles a for loop.
 *
 * The optional initializer runs first. Each iteration then starts at the loop head: the
 * "for test" block when there is a condition, otherwise the body itself. After the body,
 * "for next" evaluates the optional increment expression and jumps back to the loop head.
 * `break` targets "for end" and `continue` targets "for next".
 */
void JIT::compileFor(LLVMBuilderRef builder, const ForStatement& f) {
    if (f.fInitializer) {
        this->compileStatement(builder, *f.fInitializer);
    }
    auto newBlock = [this](const char* label) {
        return LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, label);
    };
    LLVMBasicBlockRef bodyBlock = newBlock("for body");
    LLVMBasicBlockRef nextBlock = newBlock("for next");
    LLVMBasicBlockRef endBlock = newBlock("for end");

    LLVMBasicBlockRef loopHead = f.fTest ? newBlock("for test") : bodyBlock;
    LLVMBuildBr(builder, loopHead);
    if (f.fTest) {
        this->setBlock(builder, loopHead);
        LLVMValueRef keepGoing = this->compileExpression(builder, *f.fTest);
        LLVMBuildCondBr(builder, keepGoing, bodyBlock, endBlock);
    }

    this->setBlock(builder, bodyBlock);
    fBreakTarget.push_back(endBlock);
    fContinueTarget.push_back(nextBlock);
    this->compileStatement(builder, *f.fStatement);
    fContinueTarget.pop_back();
    fBreakTarget.pop_back();
    if (!ends_with_branch(*f.fStatement)) {
        LLVMBuildBr(builder, nextBlock);
    }

    this->setBlock(builder, nextBlock);
    if (f.fNext) {
        this->compileExpression(builder, *f.fNext);
    }
    LLVMBuildBr(builder, loopHead);
    this->setBlock(builder, endBlock);
}
|
|
/**
 * Compiles a do-while loop.
 *
 * The body ("do body") runs once unconditionally. "do test" then evaluates the condition and
 * either re-enters the body or exits to "do end".
 *
 * `break` jumps to "do end". `continue` jumps to "do test", so the loop condition is still
 * evaluated, as C/GLSL do-while semantics require. Previously `continue` jumped back to the
 * body, which skipped the condition and could loop forever.
 */
void JIT::compileDo(LLVMBuilderRef builder, const DoStatement& d) {
    LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                "do test");
    LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                           "do body");
    LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                          "do end");
    LLVMBuildBr(builder, body);

    this->setBlock(builder, testBlock);
    LLVMValueRef test = this->compileExpression(builder, *d.fTest);
    LLVMBuildCondBr(builder, test, body, end);

    this->setBlock(builder, body);
    fBreakTarget.push_back(end);
    fContinueTarget.push_back(testBlock);
    this->compileStatement(builder, *d.fStatement);
    fBreakTarget.pop_back();
    fContinueTarget.pop_back();
    if (!ends_with_branch(*d.fStatement)) {
        LLVMBuildBr(builder, testBlock);
    }
    this->setBlock(builder, end);
}
|
|
/**
 * Compiles a while loop.
 *
 * "while test" evaluates the condition and branches to "while body" or "while end". The body
 * loops back to the test unless it ends in its own branch. `break` targets "while end";
 * `continue` targets "while test".
 */
void JIT::compileWhile(LLVMBuilderRef builder, const WhileStatement& w) {
    auto appendBlock = [this](const char* label) {
        return LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, label);
    };
    LLVMBasicBlockRef condBlock = appendBlock("while test");
    LLVMBasicBlockRef bodyBlock = appendBlock("while body");
    LLVMBasicBlockRef exitBlock = appendBlock("while end");

    LLVMBuildBr(builder, condBlock);
    this->setBlock(builder, condBlock);
    LLVMValueRef keepGoing = this->compileExpression(builder, *w.fTest);
    LLVMBuildCondBr(builder, keepGoing, bodyBlock, exitBlock);

    this->setBlock(builder, bodyBlock);
    fBreakTarget.push_back(exitBlock);
    fContinueTarget.push_back(condBlock);
    this->compileStatement(builder, *w.fStatement);
    fContinueTarget.pop_back();
    fBreakTarget.pop_back();
    if (!ends_with_branch(*w.fStatement)) {
        LLVMBuildBr(builder, condBlock);
    }
    this->setBlock(builder, exitBlock);
}
|
|
/** Compiles `break` as an unconditional jump to the innermost loop's entry in fBreakTarget. */
void JIT::compileBreak(LLVMBuilderRef builder, const BreakStatement& b) {
    LLVMBasicBlockRef target = fBreakTarget.back();
    LLVMBuildBr(builder, target);
}
|
|
/** Compiles `continue` as an unconditional jump to the innermost loop's fContinueTarget entry. */
void JIT::compileContinue(LLVMBuilderRef builder, const ContinueStatement& b) {
    LLVMBasicBlockRef target = fContinueTarget.back();
    LLVMBuildBr(builder, target);
}
|
|
/** Emits `ret <value>` for `return expr;`, or `ret void` for a bare `return;`. */
void JIT::compileReturn(LLVMBuilderRef builder, const ReturnStatement& r) {
    if (!r.fExpression) {
        LLVMBuildRetVoid(builder);
        return;
    }
    LLVMValueRef value = this->compileExpression(builder, *r.fExpression);
    LLVMBuildRet(builder, value);
}
|
|
/**
 * Compiles a statement in the scalar (per-pixel) code path by dispatching on its kind.
 *
 * Statement kinds this backend does not support (discard, group, switch, and anything not
 * listed) abort with a message naming the statement, rather than a bare abort().
 */
void JIT::compileStatement(LLVMBuilderRef builder, const Statement& stmt) {
    switch (stmt.fKind) {
        case Statement::kBlock_Kind:
            this->compileBlock(builder, (Block&) stmt);
            return;
        case Statement::kBreak_Kind:
            this->compileBreak(builder, (BreakStatement&) stmt);
            return;
        case Statement::kContinue_Kind:
            this->compileContinue(builder, (ContinueStatement&) stmt);
            return;
        case Statement::kDo_Kind:
            this->compileDo(builder, (DoStatement&) stmt);
            return;
        case Statement::kExpression_Kind:
            this->compileExpression(builder, *((ExpressionStatement&) stmt).fExpression);
            return;
        case Statement::kFor_Kind:
            this->compileFor(builder, (ForStatement&) stmt);
            return;
        case Statement::kIf_Kind:
            this->compileIf(builder, (IfStatement&) stmt);
            return;
        case Statement::kNop_Kind:
            return;
        case Statement::kReturn_Kind:
            this->compileReturn(builder, (ReturnStatement&) stmt);
            return;
        case Statement::kVarDeclarations_Kind:
            this->compileVarDeclarations(builder, (VarDeclarationsStatement&) stmt);
            return;
        case Statement::kWhile_Kind:
            this->compileWhile(builder, (WhileStatement&) stmt);
            return;
        case Statement::kDiscard_Kind:
        case Statement::kGroup_Kind:
        case Statement::kSwitch_Kind:
        default:
            // unsupported: fall out of the switch to the diagnostic below
            break;
    }
    ABORT("unsupported statement: %s\n", stmt.description().c_str());
}
|
|
void JIT::compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc) {
|
// loop over fVectorCount pixels, running the body of the stage function for each of them
|
LLVMValueRef oldFunction = fCurrentFunction;
|
fCurrentFunction = newFunc;
|
std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
|
LLVMGetParams(fCurrentFunction, params.get());
|
LLVMValueRef programParam = params.get()[1];
|
LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
|
LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
|
LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
|
fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
|
this->setBlock(builder, fAllocaBlock);
|
// temporaries to store the color channel vectors
|
LLVMValueRef rVec = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
|
LLVMBuildStore(builder, params.get()[4], rVec);
|
LLVMValueRef gVec = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
|
LLVMBuildStore(builder, params.get()[5], gVec);
|
LLVMValueRef bVec = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
|
LLVMBuildStore(builder, params.get()[6], bVec);
|
LLVMValueRef aVec = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
|
LLVMBuildStore(builder, params.get()[7], aVec);
|
LLVMValueRef color = LLVMBuildAlloca(builder, fFloat32Vector4Type, "color");
|
fVariables[f.fDeclaration.fParameters[1]] = LLVMBuildTrunc(builder, params.get()[3], fInt32Type,
|
"y->Int32");
|
fVariables[f.fDeclaration.fParameters[2]] = color;
|
LLVMValueRef ivar = LLVMBuildAlloca(builder, fInt32Type, "i");
|
LLVMBuildStore(builder, LLVMConstInt(fInt32Type, 0, false), ivar);
|
LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
|
this->setBlock(builder, start);
|
LLVMValueRef iload = LLVMBuildLoad(builder, ivar, "load i");
|
fVariables[f.fDeclaration.fParameters[0]] = LLVMBuildAdd(builder,
|
LLVMBuildTrunc(builder,
|
params.get()[2],
|
fInt32Type,
|
"x->Int32"),
|
iload,
|
"x");
|
LLVMValueRef vectorSize = LLVMConstInt(fInt32Type, fVectorCount, false);
|
LLVMValueRef test = LLVMBuildICmp(builder, LLVMIntSLT, iload, vectorSize, "i < vectorSize");
|
LLVMBasicBlockRef loopBody = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "body");
|
LLVMBasicBlockRef loopEnd = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "end");
|
LLVMBuildCondBr(builder, test, loopBody, loopEnd);
|
this->setBlock(builder, loopBody);
|
LLVMValueRef vec = LLVMGetUndef(fFloat32Vector4Type);
|
// extract the r, g, b, and a values from the color channel vectors and store them into "color"
|
for (int i = 0; i < 4; ++i) {
|
vec = LLVMBuildInsertElement(builder, vec,
|
LLVMBuildExtractElement(builder,
|
params.get()[4 + i],
|
iload, "initial"),
|
LLVMConstInt(fInt32Type, i, false),
|
"vec build");
|
}
|
LLVMBuildStore(builder, vec, color);
|
// write actual loop body
|
this->compileStatement(builder, *f.fBody);
|
// extract the r, g, b, and a values from "color" and stick them back into the color channel
|
// vectors
|
LLVMValueRef colorLoad = LLVMBuildLoad(builder, color, "color load");
|
LLVMBuildStore(builder,
|
LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, rVec, "rVec"),
|
LLVMBuildExtractElement(builder, colorLoad,
|
LLVMConstInt(fInt32Type, 0,
|
false),
|
"rExtract"),
|
iload, "rInsert"),
|
rVec);
|
LLVMBuildStore(builder,
|
LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, gVec, "gVec"),
|
LLVMBuildExtractElement(builder, colorLoad,
|
LLVMConstInt(fInt32Type, 1,
|
false),
|
"gExtract"),
|
iload, "gInsert"),
|
gVec);
|
LLVMBuildStore(builder,
|
LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, bVec, "bVec"),
|
LLVMBuildExtractElement(builder, colorLoad,
|
LLVMConstInt(fInt32Type, 2,
|
false),
|
"bExtract"),
|
iload, "bInsert"),
|
bVec);
|
LLVMBuildStore(builder,
|
LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, aVec, "aVec"),
|
LLVMBuildExtractElement(builder, colorLoad,
|
LLVMConstInt(fInt32Type, 3,
|
false),
|
"aExtract"),
|
iload, "aInsert"),
|
aVec);
|
LLVMValueRef inc = LLVMBuildAdd(builder, iload, LLVMConstInt(fInt32Type, 1, false), "inc i");
|
LLVMBuildStore(builder, inc, ivar);
|
LLVMBuildBr(builder, start);
|
this->setBlock(builder, loopEnd);
|
// increment program pointer, call the next stage
|
LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
|
LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
|
LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, "cast next->func");
|
LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
|
LLVMBuildAdd(builder,
|
LLVMBuildPtrToInt(builder,
|
programParam,
|
fInt64Type,
|
"cast 1"),
|
LLVMConstInt(fInt64Type, PTR_SIZE, false),
|
"add"),
|
LLVMPointerType(fInt8PtrType, 0), "cast 2");
|
LLVMValueRef args[STAGE_PARAM_COUNT] = {
|
params.get()[0],
|
nextInc,
|
params.get()[2],
|
params.get()[3],
|
LLVMBuildLoad(builder, rVec, "rVec"),
|
LLVMBuildLoad(builder, gVec, "gVec"),
|
LLVMBuildLoad(builder, bVec, "bVec"),
|
LLVMBuildLoad(builder, aVec, "aVec"),
|
params.get()[8],
|
params.get()[9],
|
params.get()[10],
|
params.get()[11]
|
};
|
LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
|
LLVMBuildRetVoid(builder);
|
// finish
|
LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
|
LLVMBuildBr(builder, start);
|
LLVMDisposeBuilder(builder);
|
if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
|
ABORT("verify failed\n");
|
}
|
fAllocaBlock = oldAllocaBlock;
|
fCurrentBlock = oldCurrentBlock;
|
fCurrentFunction = oldFunction;
|
}
|
|
// FIXME maybe pluggable code generators? Need to do something to separate all
|
// of the normal codegen from the vector codegen and break this up into multiple
|
// classes.
|
|
/**
 * Resolves an assignable expression in the vectorized code path to per-channel storage.
 *
 * Only the color parameter (fColorParam) and swizzles of it are supported. On success,
 * out[i] receives the fChannels alloca backing component i (reordered by any swizzles) and
 * true is returned. Any other expression yields false.
 */
bool JIT::getVectorLValue(LLVMBuilderRef builder, const Expression& e,
                          LLVMValueRef out[CHANNELS]) {
    if (e.fKind == Expression::kVariableReference_Kind) {
        bool isColor = &((VariableReference&) e).fVariable == fColorParam;
        if (!isColor) {
            return false;
        }
        memcpy(out, fChannels, sizeof(fChannels));
        return true;
    }
    if (e.fKind != Expression::kSwizzle_Kind) {
        return false;
    }
    const Swizzle& swizzle = (const Swizzle&) e;
    LLVMValueRef storage[CHANNELS];
    if (!this->getVectorLValue(builder, *swizzle.fBase, storage)) {
        return false;
    }
    size_t dst = 0;
    for (const auto& component : swizzle.fComponents) {
        out[dst++] = storage[component];
    }
    return true;
}
|
|
/**
 * Compiles both operands of a vectorized binary expression into per-channel values.
 *
 * When one side is a scalar (one column) and the other a vector, the scalar's single value is
 * replicated so that both arrays cover the vector's columns. Returns false if either operand
 * cannot be compiled in the vectorized path.
 */
bool JIT::getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left,
                                  LLVMValueRef outLeft[CHANNELS], const Expression& right,
                                  LLVMValueRef outRight[CHANNELS]) {
    int leftColumns = left.fType.columns();
    int rightColumns = right.fType.columns();
    // Copies values[0] into values[1 .. count-1].
    auto broadcast = [](LLVMValueRef values[], int count) {
        for (int k = 1; k < count; ++k) {
            values[k] = values[0];
        }
    };

    if (!this->compileVectorExpression(builder, left, outLeft)) {
        return false;
    }
    if (leftColumns == 1) {
        broadcast(outLeft, rightColumns);
    }

    if (!this->compileVectorExpression(builder, right, outRight)) {
        return false;
    }
    if (rightColumns == 1) {
        broadcast(outRight, leftColumns);
    }
    return true;
}
|
|
/**
 * Compiles a binary expression in the vectorized code path.
 *
 * `=` stores each channel of the right-hand side into the storage resolved by getVectorLValue.
 * +, -, *, /, %, & and | are applied channel by channel, with a scalar operand broadcast
 * against a vector one; the opcode is chosen by the left operand's type (int, uint or float).
 *
 * Returns false (so the caller can fall back to the per-pixel loop) for unsupported operators,
 * bool operands, or operands that cannot be vectorized.
 *
 * Fixes relative to the previous version:
 *  - float `%` now uses frem; srem is only valid on integers.
 *  - the channel count is the wider of the two operands, so `2 * color` computes every channel,
 *    not just the first.
 *  - bool operands fail cleanly instead of leaving `out` uninitialized.
 */
bool JIT::compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b,
                              LLVMValueRef out[CHANNELS]) {
    LLVMValueRef left[CHANNELS];
    LLVMValueRef right[CHANNELS];

    // Shared signature of the LLVM-C binary-operator builders (LLVMBuildAdd, LLVMBuildFDiv, ...).
    using BuildBinaryFn = LLVMValueRef (*)(LLVMBuilderRef, LLVMValueRef, LLVMValueRef,
                                           const char*);
    auto vectorBinary = [&](BuildBinaryFn signedOp, BuildBinaryFn unsignedOp,
                            BuildBinaryFn floatOp) {
        BuildBinaryFn op = nullptr;
        switch (this->typeKind(b.fLeft->fType)) {
            case kInt_TypeKind:
                op = signedOp;
                break;
            case kUInt_TypeKind:
                op = unsignedOp;
                break;
            case kFloat_TypeKind:
                op = floatOp;
                break;
            case kBool_TypeKind:
                break;
        }
        if (!op) {
            return false;
        }
        if (!this->getVectorBinaryOperands(builder, *b.fLeft, left, *b.fRight, right)) {
            return false;
        }
        // getVectorBinaryOperands broadcasts a scalar side up to the other side's width, so
        // both arrays are populated up to the wider operand's column count.
        int leftColumns = b.fLeft->fType.columns();
        int rightColumns = b.fRight->fType.columns();
        int columns = leftColumns > rightColumns ? leftColumns : rightColumns;
        for (int i = 0; i < columns; ++i) {
            out[i] = op(builder, left[i], right[i], "binary");
        }
        return true;
    };

    switch (b.fOperator) {
        case Token::EQ: {
            if (!this->getVectorLValue(builder, *b.fLeft, left)) {
                return false;
            }
            if (!this->compileVectorExpression(builder, *b.fRight, right)) {
                return false;
            }
            int columns = b.fRight->fType.columns();
            for (int i = 0; i < columns; ++i) {
                LLVMBuildStore(builder, right[i], left[i]);
            }
            return true;
        }
        case Token::PLUS:
            return vectorBinary(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
        case Token::MINUS:
            return vectorBinary(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
        case Token::STAR:
            return vectorBinary(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
        case Token::SLASH:
            return vectorBinary(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
        case Token::PERCENT:
            return vectorBinary(LLVMBuildSRem, LLVMBuildURem, LLVMBuildFRem);
        case Token::BITWISEAND:
            return vectorBinary(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
        case Token::BITWISEOR:
            return vectorBinary(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
        default:
            printf("unsupported operator: %s\n", b.description().c_str());
            return false;
    }
}
|
|
/**
 * Compiles a constructor expression in the vectorized code path.
 *
 * Scalar constructors convert a whole lane vector between int, uint and float. LLVM's
 * sitofp/uitofp/fptosi/fptoui accept vector operands and convert lane-wise. Same-kind and
 * int<->uint constructors pass the value through. Other conversions (bool) abort.
 *
 * Vector constructors either splat a single scalar argument across every column, or concatenate
 * the columns of their arguments in order (e.g. float4(color.rg, x, y)).
 *
 * Returns false if an argument cannot be vectorized.
 *
 * Fixes relative to the previous version:
 *  - int<->uint now actually write out[0].
 *  - a single vector argument is copied column by column instead of splatting column 0.
 *  - mixed vector/scalar argument lists are supported.
 */
bool JIT::compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c,
                                   LLVMValueRef out[CHANNELS]) {
    switch (c.fType.kind()) {
        case Type::kScalar_Kind: {
            SkASSERT(c.fArguments.size() == 1);
            TypeKind from = this->typeKind(c.fArguments[0]->fType);
            TypeKind to = this->typeKind(c.fType);
            LLVMValueRef base[CHANNELS];
            if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
                return false;
            }
            if (from == to) {
                out[0] = base[0];
                return true;
            }
            LLVMTypeRef vectorType = LLVMVectorType(this->getType(c.fType), fVectorCount);
            switch (to) {
                case kFloat_TypeKind:
                    if (kInt_TypeKind == from) {
                        out[0] = LLVMBuildSIToFP(builder, base[0], vectorType, "cast");
                        return true;
                    }
                    if (kUInt_TypeKind == from) {
                        out[0] = LLVMBuildUIToFP(builder, base[0], vectorType, "cast");
                        return true;
                    }
                    break;
                case kInt_TypeKind:
                    if (kFloat_TypeKind == from) {
                        out[0] = LLVMBuildFPToSI(builder, base[0], vectorType, "cast");
                        return true;
                    }
                    if (kUInt_TypeKind == from) {
                        out[0] = base[0];
                        return true;
                    }
                    break;
                case kUInt_TypeKind:
                    if (kFloat_TypeKind == from) {
                        out[0] = LLVMBuildFPToUI(builder, base[0], vectorType, "cast");
                        return true;
                    }
                    if (kInt_TypeKind == from) {
                        out[0] = base[0];
                        return true;
                    }
                    break;
                case kBool_TypeKind:
                    break;
            }
            printf("%s\n", c.description().c_str());
            ABORT("unsupported constructor");
        }
        case Type::kVector_Kind: {
            int columns = c.fType.columns();
            int filled = 0;
            for (const auto& arg : c.fArguments) {
                LLVMValueRef base[CHANNELS];
                if (!this->compileVectorExpression(builder, *arg, base)) {
                    return false;
                }
                int argColumns = arg->fType.columns();
                if (filled + argColumns > CHANNELS) {
                    // more columns than `out` can hold; let the caller fall back
                    return false;
                }
                for (int i = 0; i < argColumns; ++i) {
                    out[filled++] = base[i];
                }
            }
            if (c.fArguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
                // splat the single scalar into every column
                for (int i = 1; i < columns; ++i) {
                    out[i] = out[0];
                }
            } else {
                SkASSERT(filled == columns);
            }
            return true;
        }
        default:
            break;
    }
    ABORT("unsupported constructor");
}
|
|
/**
 * Compiles a float literal in the vectorized code path as a constant vector with fVectorCount
 * identical lanes, written to out[0]. Always succeeds.
 *
 * The lane buffer is sized from fVectorCount. The previous fixed MAX_VECTOR_COUNT stack array
 * was written without a bounds check and would overflow if fVectorCount ever exceeded it.
 */
bool JIT::compileVectorFloatLiteral(LLVMBuilderRef builder,
                                    const FloatLiteral& f,
                                    LLVMValueRef out[CHANNELS]) {
    LLVMValueRef value = LLVMConstReal(this->getType(f.fType), f.fValue);
    std::vector<LLVMValueRef> values(fVectorCount, value);
    out[0] = LLVMConstVector(values.data(), fVectorCount);
    return true;
}
|
|
|
/**
 * Compiles a swizzle in the vectorized code path. Every channel of the base is compiled, and
 * out[i] is set to channel fComponents[i]. No instructions are needed to reorder. Returns false
 * if the base cannot be vectorized.
 */
bool JIT::compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s,
                               LLVMValueRef out[CHANNELS]) {
    LLVMValueRef channels[CHANNELS];
    bool compiled = this->compileVectorExpression(builder, *s.fBase, channels);
    if (compiled) {
        size_t dst = 0;
        for (const auto& component : s.fComponents) {
            out[dst++] = channels[component];
        }
    }
    return compiled;
}
|
|
/**
 * Compiles a variable reference in the vectorized code path. Only the color parameter is
 * supported: each of its CHANNELS channel vectors is loaded from fChannels into out[].
 * Returns false for any other variable.
 */
bool JIT::compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v,
                                         LLVMValueRef out[CHANNELS]) {
    if (&v.fVariable != fColorParam) {
        return false;
    }
    for (int channel = 0; channel < CHANNELS; ++channel) {
        out[channel] = LLVMBuildLoad(builder, fChannels[channel], "variable reference");
    }
    return true;
}
|
|
/**
 * Compiles an expression in the vectorized code path into one value per channel (out[i] holds
 * all pixels' values for component i).
 *
 * Supports binary expressions, constructors, float literals, swizzles and variable references.
 * Returns false for anything else, so the caller can fall back to the per-pixel loop.
 */
bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr,
                                  LLVMValueRef out[CHANNELS]) {
    const auto kind = expr.fKind;
    if (kind == Expression::kBinary_Kind) {
        return this->compileVectorBinary(builder, (const BinaryExpression&) expr, out);
    }
    if (kind == Expression::kConstructor_Kind) {
        return this->compileVectorConstructor(builder, (const Constructor&) expr, out);
    }
    if (kind == Expression::kFloatLiteral_Kind) {
        return this->compileVectorFloatLiteral(builder, (const FloatLiteral&) expr, out);
    }
    if (kind == Expression::kSwizzle_Kind) {
        return this->compileVectorSwizzle(builder, (const Swizzle&) expr, out);
    }
    if (kind == Expression::kVariableReference_Kind) {
        return this->compileVectorVariableReference(builder, (const VariableReference&) expr,
                                                    out);
    }
    return false;
}
|
|
/**
 * Compiles a statement in the vectorized code path. Only blocks and expression statements are
 * supported. Returns false as soon as any part cannot be vectorized.
 *
 * The scratch buffer for an expression statement's (discarded) result must hold CHANNELS
 * entries. compileVectorExpression may write one value per channel (e.g. for a reference to the
 * color), so the previous single LLVMValueRef overflowed the stack.
 */
bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt) {
    switch (stmt.fKind) {
        case Statement::kBlock_Kind:
            for (const auto& s : ((const Block&) stmt).fStatements) {
                if (!this->compileVectorStatement(builder, *s)) {
                    return false;
                }
            }
            return true;
        case Statement::kExpression_Kind: {
            LLVMValueRef result[CHANNELS];
            return this->compileVectorExpression(builder,
                                                 *((const ExpressionStatement&) stmt).fExpression,
                                                 result);
        }
        default:
            return false;
    }
}
|
|
bool JIT::compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc) {
|
LLVMValueRef oldFunction = fCurrentFunction;
|
fCurrentFunction = newFunc;
|
std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
|
LLVMGetParams(fCurrentFunction, params.get());
|
LLVMValueRef programParam = params.get()[1];
|
LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
|
LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
|
LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
|
fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
|
this->setBlock(builder, fAllocaBlock);
|
fChannels[0] = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
|
LLVMBuildStore(builder, params.get()[4], fChannels[0]);
|
fChannels[1] = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
|
LLVMBuildStore(builder, params.get()[5], fChannels[1]);
|
fChannels[2] = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
|
LLVMBuildStore(builder, params.get()[6], fChannels[2]);
|
fChannels[3] = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
|
LLVMBuildStore(builder, params.get()[7], fChannels[3]);
|
LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
|
this->setBlock(builder, start);
|
bool success = this->compileVectorStatement(builder, *f.fBody);
|
if (success) {
|
// increment program pointer, call next
|
LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
|
LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
|
LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType,
|
"cast next->func");
|
LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
|
LLVMBuildAdd(builder,
|
LLVMBuildPtrToInt(builder,
|
programParam,
|
fInt64Type,
|
"cast 1"),
|
LLVMConstInt(fInt64Type, PTR_SIZE,
|
false),
|
"add"),
|
LLVMPointerType(fInt8PtrType, 0), "cast 2");
|
LLVMValueRef args[STAGE_PARAM_COUNT] = {
|
params.get()[0],
|
nextInc,
|
params.get()[2],
|
params.get()[3],
|
LLVMBuildLoad(builder, fChannels[0], "rVec"),
|
LLVMBuildLoad(builder, fChannels[1], "gVec"),
|
LLVMBuildLoad(builder, fChannels[2], "bVec"),
|
LLVMBuildLoad(builder, fChannels[3], "aVec"),
|
params.get()[8],
|
params.get()[9],
|
params.get()[10],
|
params.get()[11]
|
};
|
LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
|
LLVMBuildRetVoid(builder);
|
// finish
|
LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
|
LLVMBuildBr(builder, start);
|
LLVMDisposeBuilder(builder);
|
if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
|
ABORT("verify failed\n");
|
}
|
} else {
|
LLVMDeleteBasicBlock(fAllocaBlock);
|
LLVMDeleteBasicBlock(start);
|
}
|
|
fAllocaBlock = oldAllocaBlock;
|
fCurrentBlock = oldCurrentBlock;
|
fCurrentFunction = oldFunction;
|
return success;
|
}
|
|
/**
 * Declares "<name>$stage", an LLVM function with the raster pipeline stage signature, and
 * generates its body from `f`.
 *
 * The vectorized compiler is tried first; if it fails, the per-pixel loop fallback is used.
 * The parameter list is (size_t, int8**, size_t x, size_t y) followed by eight float vectors.
 * The first four vectors are the r, g, b, a channels used by the generated code; the remaining
 * four are passed through to the next stage untouched. The array is sized by
 * STAGE_PARAM_COUNT, so the two cannot drift apart.
 *
 * @return the new stage function
 */
LLVMValueRef JIT::compileStageFunction(const FunctionDefinition& f) {
    static_assert(STAGE_PARAM_COUNT == 12,
                  "stage parameter type list must match STAGE_PARAM_COUNT");
    LLVMTypeRef returnType = fVoidType;
    LLVMTypeRef parameterTypes[STAGE_PARAM_COUNT] = {
            fSizeTType, LLVMPointerType(fInt8PtrType, 0), fSizeTType, fSizeTType,
            fFloat32VectorType, fFloat32VectorType, fFloat32VectorType, fFloat32VectorType,
            fFloat32VectorType, fFloat32VectorType, fFloat32VectorType, fFloat32VectorType };
    LLVMTypeRef stageFuncType = LLVMFunctionType(returnType, parameterTypes, STAGE_PARAM_COUNT,
                                                 false);
    LLVMValueRef result = LLVMAddFunction(fModule,
                                          (String(f.fDeclaration.fName) + "$stage").c_str(),
                                          stageFuncType);
    fColorParam = f.fDeclaration.fParameters[2];
    if (!this->compileStageFunctionVector(f, result)) {
        // vectorization failed, fall back to looping over the pixels
        this->compileStageFunctionLoop(f, result);
    }
    return result;
}
|
|
bool JIT::hasStageSignature(const FunctionDeclaration& f) {
|
return f.fReturnType == *fProgram->fContext->fVoid_Type &&
|
f.fParameters.size() == 3 &&
|
f.fParameters[0]->fType == *fProgram->fContext->fInt_Type &&
|
f.fParameters[0]->fModifiers.fFlags == 0 &&
|
f.fParameters[1]->fType == *fProgram->fContext->fInt_Type &&
|
f.fParameters[1]->fModifiers.fFlags == 0 &&
|
f.fParameters[2]->fType == *fProgram->fContext->fHalf4_Type &&
|
f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag);
|
}
|
|
/**
 * Compiles an SkSL function into an LLVM function of the same name and records it in
 * fFunctions.
 *
 * `out` parameters become pointer parameters. Each SkSL parameter is bound to its LLVM argument
 * in fVariables. A body that does not end in a branch gets `ret void` (void functions) or
 * `unreachable` appended. If the declaration matches the stage signature, a separate
 * "<name>$stage" function is generated first.
 *
 * @return the new LLVM function (also left in fCurrentFunction)
 */
LLVMValueRef JIT::compileFunction(const FunctionDefinition& f) {
    const auto& decl = f.fDeclaration;
    if (this->hasStageSignature(decl)) {
        // Compile foo$stage *in addition* to foo: a matching signature does not prove the author
        // meant an SkJumper stage, nor that nothing else calls the function. May need a better
        // way to handle this.
        this->compileStageFunction(f);
    }

    LLVMTypeRef returnType = this->getType(decl.fReturnType);
    std::vector<LLVMTypeRef> parameterTypes;
    parameterTypes.reserve(decl.fParameters.size());
    for (const auto& param : decl.fParameters) {
        LLVMTypeRef type = this->getType(param->fType);
        bool byPointer = (param->fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
        parameterTypes.push_back(byPointer ? LLVMPointerType(type, 0) : type);
    }
    LLVMTypeRef functionType = LLVMFunctionType(returnType, parameterTypes.data(),
                                                parameterTypes.size(), false);
    fCurrentFunction = LLVMAddFunction(fModule, String(decl.fName).c_str(), functionType);
    fFunctions[&decl] = fCurrentFunction;

    std::vector<LLVMValueRef> arguments(parameterTypes.size());
    LLVMGetParams(fCurrentFunction, arguments.data());
    for (size_t i = 0; i < decl.fParameters.size(); ++i) {
        fVariables[decl.fParameters[i]] = arguments[i];
    }

    LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
    fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
    LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
    fCurrentBlock = start;
    LLVMPositionBuilderAtEnd(builder, start);
    this->compileStatement(builder, *f.fBody);
    if (!ends_with_branch(*f.fBody)) {
        bool returnsVoid = decl.fReturnType == *fProgram->fContext->fVoid_Type;
        if (returnsVoid) {
            LLVMBuildRetVoid(builder);
        } else {
            LLVMBuildUnreachable(builder);
        }
    }
    // All allocas have been emitted now; let the alloca block fall through to the body.
    LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
    LLVMBuildBr(builder, start);
    LLVMDisposeBuilder(builder);
    if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
        ABORT("verify failed\n");
    }
    return fCurrentFunction;
}
|
|
void JIT::createModule() {
|
fPromotedParameters.clear();
|
fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext);
|
this->loadBuiltinFunctions();
|
LLVMTypeRef fold2Params[1] = { fInt1Vector2Type };
|
fFoldAnd2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v2i1",
|
LLVMFunctionType(fInt1Type, fold2Params, 1, false));
|
fFoldOr2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v2i1",
|
LLVMFunctionType(fInt1Type, fold2Params, 1, false));
|
LLVMTypeRef fold3Params[1] = { fInt1Vector3Type };
|
fFoldAnd3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v3i1",
|
LLVMFunctionType(fInt1Type, fold3Params, 1, false));
|
fFoldOr3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v3i1",
|
LLVMFunctionType(fInt1Type, fold3Params, 1, false));
|
LLVMTypeRef fold4Params[1] = { fInt1Vector4Type };
|
fFoldAnd4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v4i1",
|
LLVMFunctionType(fInt1Type, fold4Params, 1, false));
|
fFoldOr4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v4i1",
|
LLVMFunctionType(fInt1Type, fold4Params, 1, false));
|
// LLVM doesn't do void*, have to declare it as int8*
|
LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType };
|
fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType,
|
appendParams,
|
3,
|
false));
|
LLVMTypeRef appendCallbackParams[2] = { fInt8PtrType, fInt8PtrType };
|
fAppendCallbackFunc = LLVMAddFunction(fModule, "sksl_pipeline_append_callback",
|
LLVMFunctionType(fVoidType, appendCallbackParams, 2,
|
false));
|
|
LLVMTypeRef debugParams[3] = { fFloat32Type };
|
fDebugFunc = LLVMAddFunction(fModule, "sksl_debug_print", LLVMFunctionType(fVoidType,
|
debugParams,
|
1,
|
false));
|
|
for (const auto& e : *fProgram) {
|
if (e.fKind == ProgramElement::kFunction_Kind) {
|
this->compileFunction((FunctionDefinition&) e);
|
}
|
}
|
}
|
|
std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) {
|
fCompiler.optimize(*program);
|
fProgram = std::move(program);
|
this->createModule();
|
this->optimize();
|
return std::unique_ptr<Module>(new Module(std::move(fProgram), fSharedModule, fJITStack));
|
}
|
|
void JIT::optimize() {
|
LLVMPassManagerBuilderRef pmb = LLVMPassManagerBuilderCreate();
|
LLVMPassManagerBuilderSetOptLevel(pmb, 3);
|
LLVMPassManagerRef functionPM = LLVMCreateFunctionPassManagerForModule(fModule);
|
LLVMPassManagerBuilderPopulateFunctionPassManager(pmb, functionPM);
|
LLVMPassManagerRef modulePM = LLVMCreatePassManager();
|
LLVMPassManagerBuilderPopulateModulePassManager(pmb, modulePM);
|
LLVMInitializeFunctionPassManager(functionPM);
|
|
LLVMValueRef func = LLVMGetFirstFunction(fModule);
|
for (;;) {
|
if (!func) {
|
break;
|
}
|
LLVMRunFunctionPassManager(functionPM, func);
|
func = LLVMGetNextFunction(func);
|
}
|
LLVMRunPassManager(modulePM, fModule);
|
LLVMDisposePassManager(functionPM);
|
LLVMDisposePassManager(modulePM);
|
LLVMPassManagerBuilderDispose(pmb);
|
|
std::string error_string;
|
if (LLVMLoadLibraryPermanently(nullptr)) {
|
ABORT("LLVMLoadLibraryPermanently failed");
|
}
|
char* defaultTriple = LLVMGetDefaultTargetTriple();
|
char* error;
|
LLVMTargetRef target;
|
if (LLVMGetTargetFromTriple(defaultTriple, &target, &error)) {
|
ABORT("LLVMGetTargetFromTriple failed");
|
}
|
|
if (!LLVMTargetHasJIT(target)) {
|
ABORT("!LLVMTargetHasJIT");
|
}
|
LLVMTargetMachineRef targetMachine = LLVMCreateTargetMachine(target,
|
defaultTriple,
|
fCPU,
|
nullptr,
|
LLVMCodeGenLevelDefault,
|
LLVMRelocDefault,
|
LLVMCodeModelJITDefault);
|
LLVMDisposeMessage(defaultTriple);
|
LLVMTargetDataRef dataLayout = LLVMCreateTargetDataLayout(targetMachine);
|
LLVMSetModuleDataLayout(fModule, dataLayout);
|
LLVMDisposeTargetData(dataLayout);
|
|
fJITStack = LLVMOrcCreateInstance(targetMachine);
|
fSharedModule = LLVMOrcMakeSharedModule(fModule);
|
LLVMOrcModuleHandle orcModule;
|
LLVMOrcAddEagerlyCompiledIR(fJITStack, &orcModule, fSharedModule,
|
(LLVMOrcSymbolResolverFn) resolveSymbol, this);
|
LLVMDisposeTargetMachine(targetMachine);
|
}
|
|
void* JIT::Module::getSymbol(const char* name) {
|
LLVMOrcTargetAddress result;
|
if (LLVMOrcGetSymbolAddress(fJITStack, &result, name)) {
|
ABORT("GetSymbolAddress error");
|
}
|
if (!result) {
|
ABORT("symbol not found");
|
}
|
return (void*) result;
|
}
|
|
void* JIT::Module::getJumperStage(const char* name) {
|
return this->getSymbol((String(name) + "$stage").c_str());
|
}
|
|
} // namespace
|
|
#endif // SK_LLVM_AVAILABLE
|
|
#endif // SKSL_STANDALONE
|