/*
|
* Copyright (C) 2018 The Android Open Source Project
|
*
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
*/
|
|
#include "utils/lua-utils.h"
|
|
// lua_dump takes an extra argument "strip" in 5.3, but not in 5.2.
// When TC3_AOSP is not defined, the macro below lets call sites always pass
// four arguments: it forwards the first three to the real lua_dump and drops
// the fourth ("strip").
// NOTE(review): presumably non-AOSP builds link against Lua 5.2 - confirm.
|
#ifndef TC3_AOSP
|
#define lua_dump(L, w, d, s) lua_dump((L), (w), (d))
|
#endif
|
|
namespace libtextclassifier3 {
|
namespace {
|
// Positions of the closure upvalues read back by the flatbuffer field
// callback: schema, object type and table, in the order they are pushed.
constexpr int kSchemaArgId = 1;
constexpr int kTypeArgId = 2;
constexpr int kTableArgId = 3;
|
|
// Libraries opened by LuaEnvironment::LoadDefaultLibraries. The list is
// terminated by an entry whose `func` is null.
static constexpr luaL_Reg defaultlibs[] = {
    {"_G", luaopen_base},
    {LUA_TABLIBNAME, luaopen_table},
    {LUA_STRLIBNAME, luaopen_string},
    {LUA_BITLIBNAME, luaopen_bit32},
    {LUA_MATHLIBNAME, luaopen_math},
    {nullptr, nullptr},
};
|
|
// Implementation of a lua_Writer that appends the data to a string.
|
int LuaStringWriter(lua_State *state, const void *data, size_t size,
|
void *result) {
|
std::string *const result_string = static_cast<std::string *>(result);
|
result_string->insert(result_string->size(), static_cast<const char *>(data),
|
size);
|
return LUA_OK;
|
}
|
|
} // namespace
|
|
// Creates a new Lua state and stores it in `state_`.
LuaEnvironment::LuaEnvironment() {
  state_ = luaL_newstate();
}
|
|
// Closes the Lua state held in `state_`, if there is one.
LuaEnvironment::~LuaEnvironment() {
  if (state_ == nullptr) {
    return;
  }
  lua_close(state_);
}
|
|
// C callback: fetches the Iterator stored in the closure upvalue
// `kIteratorArgId` and forwards to its Next().
int LuaEnvironment::Iterator::NextCallback(lua_State *state) {
  auto *const iterator = FromUpValue<Iterator *>(kIteratorArgId, state);
  return iterator->Next(state);
}
|
|
// C callback: fetches the Iterator stored in the closure upvalue
// `kIteratorArgId` and forwards to its Length().
int LuaEnvironment::Iterator::LengthCallback(lua_State *state) {
  auto *const iterator = FromUpValue<Iterator *>(kIteratorArgId, state);
  return iterator->Length(state);
}
|
|
// C callback: fetches the Iterator stored in the closure upvalue
// `kIteratorArgId` and forwards to its Item().
int LuaEnvironment::Iterator::ItemCallback(lua_State *state) {
  auto *const iterator = FromUpValue<Iterator *>(kIteratorArgId, state);
  return iterator->Item(state);
}
|
|
// C callback: fetches the Iterator stored in the closure upvalue
// `kIteratorArgId` and forwards to its Iteritems().
int LuaEnvironment::Iterator::IteritemsCallback(lua_State *state) {
  auto *const iterator = FromUpValue<Iterator *>(kIteratorArgId, state);
  return iterator->Iteritems(state);
}
|
|
void LuaEnvironment::PushFlatbuffer(const char *name,
|
const reflection::Schema *schema,
|
const reflection::Object *type,
|
const flatbuffers::Table *table,
|
lua_State *state) {
|
lua_newtable(state);
|
luaL_newmetatable(state, name);
|
lua_pushlightuserdata(state, AsUserData(schema));
|
lua_pushlightuserdata(state, AsUserData(type));
|
lua_pushlightuserdata(state, AsUserData(table));
|
lua_pushcclosure(state, &GetFieldCallback, 3);
|
lua_setfield(state, -2, kIndexKey);
|
lua_setmetatable(state, -2);
|
}
|
|
// C callback for field lookups: recovers the schema, object type and table
// from the closure upvalues and delegates to GetField.
int LuaEnvironment::GetFieldCallback(lua_State *state) {
  return GetField(FromUpValue<reflection::Schema *>(kSchemaArgId, state),
                  FromUpValue<reflection::Object *>(kTypeArgId, state),
                  FromUpValue<flatbuffers::Table *>(kTableArgId, state),
                  state);
}
|
|
// Looks up the field named by the value on top of the Lua stack in `table`
// and pushes the field's value onto the stack.
//
// `type` describes the layout of `table`; `schema` is used to resolve the
// object type of nested tables. Scalars and strings are pushed directly (an
// unset string is pushed as the empty string); nested objects are pushed as
// lazily resolved proxies via PushFlatbuffer.
//
// Returns 1 (the number of pushed results) on success. Raises a Lua error if
// the key is not a string, the field is unknown, a nested object is unset, or
// the field type is unsupported.
int LuaEnvironment::GetField(const reflection::Schema *schema,
                             const reflection::Object *type,
                             const flatbuffers::Table *table,
                             lua_State *state) {
  // lua_tostring yields nullptr for keys that are neither strings nor
  // numbers (e.g. nil or a table); LookupByKey must not see a null key.
  const char *field_name = lua_tostring(state, -1);
  if (field_name == nullptr) {
    TC3_LOG(ERROR) << "Field name is not a string.";
    return luaL_error(state, "flatbuffer field name must be a string");
  }
  const reflection::Field *field = type->fields()->LookupByKey(field_name);
  if (field == nullptr) {
    TC3_LOG(ERROR) << "Unknown field: " << field_name;
    lua_error(state);
    return 0;
  }

  // Provide primitive fields directly.
  const reflection::BaseType field_type = field->type()->base_type();
  switch (field_type) {
    case reflection::Bool:
      lua_pushboolean(state, table->GetField<uint8_t>(
                                 field->offset(), field->default_integer()));
      break;
    case reflection::Int:
      lua_pushinteger(state, table->GetField<int32>(field->offset(),
                                                    field->default_integer()));
      break;
    case reflection::Long:
      lua_pushinteger(state, table->GetField<int64>(field->offset(),
                                                    field->default_integer()));
      break;
    case reflection::Float:
      lua_pushnumber(state, table->GetField<float>(field->offset(),
                                                   field->default_real()));
      break;
    case reflection::Double:
      lua_pushnumber(state, table->GetField<double>(field->offset(),
                                                    field->default_real()));
      break;
    case reflection::String: {
      const flatbuffers::String *string_value =
          table->GetPointer<const flatbuffers::String *>(field->offset());
      if (string_value != nullptr) {
        lua_pushlstring(state, string_value->data(), string_value->Length());
      } else {
        lua_pushlstring(state, "", 0);
      }
      break;
    }
    case reflection::Obj: {
      const flatbuffers::Table *field_table =
          table->GetPointer<const flatbuffers::Table *>(field->offset());
      if (field_table == nullptr) {
        TC3_LOG(ERROR) << "Field was not set in entity data.";
        lua_error(state);
        return 0;
      }
      // Named differently from `field_type` above to avoid shadowing it.
      const reflection::Object *field_object_type =
          schema->objects()->Get(field->type()->index());
      PushFlatbuffer(field->name()->c_str(), schema, field_object_type,
                     field_table, state);
      break;
    }
    default:
      TC3_LOG(ERROR) << "Unsupported type: " << field_type;
      lua_error(state);
      return 0;
  }
  return 1;
}
|
|
// Copies the Lua table on top of the stack into `buffer`, one field per
// key/value pair. Keys are matched against the fields known to `buffer`;
// nested tables are read recursively into the corresponding sub-buffer.
//
// Returns LUA_OK once every entry has been consumed; the table itself is
// left on top of the stack. Raises a Lua error if the value on top of the
// stack is not a table, if a key does not name a known field, or if a field
// has an unsupported type.
int LuaEnvironment::ReadFlatbuffer(ReflectiveFlatbuffer *buffer) {
  if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected actions table, got: "
                   << lua_type(state_, /*idx=*/-1);
    lua_error(state_);
    return LUA_ERRRUN;
  }

  // Inside the loop the stack is: ..., table, key, value.
  lua_pushnil(state_);
  while (lua_next(state_, /*idx=*/-2)) {
    const StringPiece key = ReadString(/*index=*/-2);
    const reflection::Field *field = buffer->GetFieldOrNull(key);
    if (field == nullptr) {
      TC3_LOG(ERROR) << "Unknown field: " << key.ToString();
      lua_error(state_);
      return LUA_ERRRUN;
    }
    switch (field->type()->base_type()) {
      case reflection::Obj: {
        // The nested table is the value on top of the stack, which is where
        // the recursive call expects it. Keep iterating afterwards instead
        // of returning, so the remaining fields of this table are read and
        // the value is popped below.
        const int status = ReadFlatbuffer(buffer->Mutable(field));
        if (status != LUA_OK) {
          return status;
        }
        break;
      }
      case reflection::Bool:
        buffer->Set(field,
                    static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
        break;
      case reflection::Int:
        buffer->Set(field, static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
        break;
      case reflection::Long:
        buffer->Set(field,
                    static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
        break;
      case reflection::Float:
        buffer->Set(field,
                    static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
        break;
      case reflection::Double:
        buffer->Set(field,
                    static_cast<double>(lua_tonumber(state_, /*idx=*/-1)));
        break;
      case reflection::String: {
        buffer->Set(field, ReadString(/*index=*/-1));
        break;
      }
      default:
        TC3_LOG(ERROR) << "Unsupported type: " << field->type()->base_type();
        lua_error(state_);
        return LUA_ERRRUN;
    }
    // Drop the value; the key stays on the stack for the next lua_next.
    lua_pop(state_, 1);
  }
  return LUA_OK;
}
|
|
void LuaEnvironment::LoadDefaultLibraries() {
|
for (const luaL_Reg *lib = defaultlibs; lib->func; lib++) {
|
luaL_requiref(state_, lib->name, lib->func, 1);
|
lua_pop(state_, 1); /* remove lib */
|
}
|
}
|
|
void LuaEnvironment::PushValue(const Variant &value) {
|
if (value.HasInt()) {
|
lua_pushnumber(state_, value.IntValue());
|
} else if (value.HasInt64()) {
|
lua_pushnumber(state_, value.Int64Value());
|
} else if (value.HasBool()) {
|
lua_pushboolean(state_, value.BoolValue());
|
} else if (value.HasFloat()) {
|
lua_pushnumber(state_, value.FloatValue());
|
} else if (value.HasDouble()) {
|
lua_pushnumber(state_, value.DoubleValue());
|
} else if (value.HasString()) {
|
lua_pushlstring(state_, value.StringValue().data(),
|
value.StringValue().size());
|
} else {
|
TC3_LOG(FATAL) << "Unknown value type.";
|
}
|
}
|
|
// Returns a view of the string at stack position `index`, as reported by
// lua_tolstring. The view points into memory owned by the Lua state.
StringPiece LuaEnvironment::ReadString(const int index) const {
  size_t num_bytes = 0;
  const char *const bytes = lua_tolstring(state_, index, &num_bytes);
  return StringPiece(bytes, num_bytes);
}
|
|
// Pushes the bytes referenced by `str` onto the Lua stack as a string. The
// length is passed explicitly, so embedded NUL bytes are preserved.
void LuaEnvironment::PushString(const StringPiece str) {
  const char *const bytes = str.data();
  lua_pushlstring(state_, bytes, str.size());
}
|
|
// Pushes `table` onto this environment's stack as an instance of the root
// table type of `schema`, using the root type's name as the metatable name.
void LuaEnvironment::PushFlatbuffer(const reflection::Schema *schema,
                                    const flatbuffers::Table *table) {
  const auto *const root_type = schema->root_table();
  const char *const root_name = root_type->name()->c_str();
  PushFlatbuffer(root_name, schema, root_type, table, state_);
}
|
|
int LuaEnvironment::RunProtected(const std::function<int()> &func,
|
const int num_args, const int num_results) {
|
struct ProtectedCall {
|
std::function<int()> func;
|
|
static int run(lua_State *state) {
|
// Read the pointer to the ProtectedCall struct.
|
ProtectedCall *p = static_cast<ProtectedCall *>(
|
lua_touserdata(state, lua_upvalueindex(1)));
|
return p->func();
|
}
|
};
|
ProtectedCall protected_call = {func};
|
lua_pushlightuserdata(state_, &protected_call);
|
lua_pushcclosure(state_, &ProtectedCall::run, /*n=*/1);
|
// Put the closure before the arguments on the stack.
|
if (num_args > 0) {
|
lua_insert(state_, -(1 + num_args));
|
}
|
return lua_pcall(state_, num_args, num_results, /*errorfunc=*/0);
|
}
|
|
// Compiles `snippet` and appends the dumped chunk to `bytecode`. Returns
// false, after logging, if the snippet does not load or cannot be dumped.
// Whatever the load or dump left on the stack is popped before returning.
bool LuaEnvironment::Compile(StringPiece snippet, std::string *bytecode) {
  const int load_status = luaL_loadbuffer(state_, snippet.data(),
                                          snippet.size(), /*name=*/nullptr);
  if (load_status != LUA_OK) {
    // The error message is on top of the stack.
    TC3_LOG(ERROR) << "Could not compile lua snippet: "
                   << ReadString(/*index=*/-1).ToString();
    lua_pop(state_, 1);
    return false;
  }

  const bool dumped =
      lua_dump(state_, LuaStringWriter, bytecode, /*strip*/ 1) == LUA_OK;
  if (!dumped) {
    TC3_LOG(ERROR) << "Could not dump compiled lua snippet.";
  }
  // Remove the loaded chunk in either case.
  lua_pop(state_, 1);
  return dumped;
}
|
|
// Convenience wrapper: compiles `snippet` into `bytecode` using a
// LuaEnvironment that lives only for the duration of the call.
bool Compile(StringPiece snippet, std::string *bytecode) {
  LuaEnvironment environment;
  return environment.Compile(snippet, bytecode);
}
|
|
} // namespace libtextclassifier3
|