lin
2025-03-11 6f4f7a76e03a46fefb056a4b18197f1d9e8aa939
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
/*
 * Copyright 2015, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 
#include "slang_rs_foreach_lowering.h"
 
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "llvm/Support/raw_ostream.h"
#include "slang_rs_context.h"
#include "slang_rs_export_foreach.h"
 
namespace slang {
 
namespace {
 
const char KERNEL_LAUNCH_FUNCTION_NAME[] = "rsForEach";
const char KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS[] = "rsForEachWithOptions";
const char INTERNAL_LAUNCH_FUNCTION_NAME[] =
    "_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation";
 
}  // anonymous namespace
 
RSForEachLowering::RSForEachLowering(RSContext* ctxt)
    : mCtxt(ctxt), mASTCtxt(ctxt->getASTContext()) {}
 
// Check if the passed-in expr references a kernel function in the following
// pattern in the AST.
//
// ImplicitCastExpr 'void *' <BitCast>
//  `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
//    `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
const clang::FunctionDecl* RSForEachLowering::matchFunctionDesignator(
    clang::Expr* expr) {
  clang::ImplicitCastExpr* ToVoidPtr =
      clang::dyn_cast<clang::ImplicitCastExpr>(expr);
  if (ToVoidPtr == nullptr) {
    return nullptr;
  }
 
  clang::ImplicitCastExpr* Decay =
      clang::dyn_cast<clang::ImplicitCastExpr>(ToVoidPtr->getSubExpr());
 
  if (Decay == nullptr) {
    return nullptr;
  }
 
  clang::DeclRefExpr* DRE =
      clang::dyn_cast<clang::DeclRefExpr>(Decay->getSubExpr());
 
  if (DRE == nullptr) {
    return nullptr;
  }
 
  const clang::FunctionDecl* FD =
      clang::dyn_cast<clang::FunctionDecl>(DRE->getDecl());
 
  if (FD == nullptr) {
    return nullptr;
  }
 
  return FD;
}
 
// Checks if the call expression is a legal rsForEach call by looking for the
// following pattern in the AST. On success, returns the first argument that is
// a FunctionDecl of a kernel function.
//
// CallExpr 'void'
// |
// |-ImplicitCastExpr 'void (*)(void *, ...)' <FunctionToPointerDecay>
// | `-DeclRefExpr  'void (void *, ...)'  'rsForEach' 'void (void *, ...)'
// |
// |-ImplicitCastExpr 'void *' <BitCast>
// | `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
// |   `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
// |
// |-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
// | `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'in' 'rs_allocation':'rs_allocation'
// |
// `-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
//   `-DeclRefExpr  'rs_allocation':'rs_allocation' lvalue ParmVar 'out' 'rs_allocation':'rs_allocation'
const clang::FunctionDecl* RSForEachLowering::matchKernelLaunchCall(
    clang::CallExpr* CE, int* slot, bool* hasOptions) {
  const clang::Decl* D = CE->getCalleeDecl();
  const clang::FunctionDecl* FD = clang::dyn_cast<clang::FunctionDecl>(D);
 
  if (FD == nullptr) {
    return nullptr;
  }
 
  const clang::StringRef& funcName = FD->getName();
 
  if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME)) {
    *hasOptions = false;
  } else if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS)) {
    *hasOptions = true;
  } else {
    return nullptr;
  }
 
  if (mInsideKernel) {
    mCtxt->ReportError(CE->getExprLoc(),
        "Invalid kernel launch call made from inside another kernel.");
    return nullptr;
  }
 
  clang::Expr* arg0 = CE->getArg(0);
  const clang::FunctionDecl* kernel = matchFunctionDesignator(arg0);
 
  if (kernel == nullptr) {
    mCtxt->ReportError(arg0->getExprLoc(),
                       "Invalid kernel launch call. "
                       "Expects a function designator for the first argument.");
    return nullptr;
  }
 
  // Verifies that kernel is indeed a "kernel" function.
  *slot = mCtxt->getForEachSlotNumber(kernel);
  if (*slot == -1) {
    mCtxt->ReportError(CE->getExprLoc(),
         "%0 applied to function %1 defined without \"kernel\" attribute")
         << funcName << kernel->getName();
    return nullptr;
  }
 
  return kernel;
}
 
// Create an AST node for the declaration of rsForEachInternal
clang::FunctionDecl* RSForEachLowering::CreateForEachInternalFunctionDecl() {
  clang::DeclContext* DC = mASTCtxt.getTranslationUnitDecl();
  clang::SourceLocation Loc;
 
  llvm::StringRef SR(INTERNAL_LAUNCH_FUNCTION_NAME);
  clang::IdentifierInfo& II = mASTCtxt.Idents.get(SR);
  clang::DeclarationName N(&II);
 
  clang::FunctionProtoType::ExtProtoInfo EPI;
 
  const clang::QualType& AllocTy = mCtxt->getAllocationType();
  clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy);
 
  clang::QualType ScriptCallTy = mCtxt->getScriptCallType();
  const clang::QualType ScriptCallPtrTy = mASTCtxt.getPointerType(ScriptCallTy);
 
  clang::QualType ParamTypes[] = {
    mASTCtxt.IntTy,   // int slot
    ScriptCallPtrTy,  // rs_script_call_t* launch_options
    mASTCtxt.IntTy,   // int numOutput
    mASTCtxt.IntTy,   // int numInputs
    AllocPtrTy        // rs_allocation* allocs
  };
 
  clang::QualType T = mASTCtxt.getFunctionType(
      mASTCtxt.VoidTy,  // Return type
      ParamTypes,       // Parameter types
      EPI);
 
  clang::FunctionDecl* FD = clang::FunctionDecl::Create(
      mASTCtxt, DC, Loc, Loc, N, T, nullptr, clang::SC_Extern);
 
  static constexpr unsigned kNumParams = sizeof(ParamTypes) / sizeof(ParamTypes[0]);
  clang::ParmVarDecl *ParamDecls[kNumParams];
  for (unsigned I = 0; I != kNumParams; ++I) {
    ParamDecls[I] = clang::ParmVarDecl::Create(mASTCtxt, FD, Loc,
        Loc, nullptr, ParamTypes[I], nullptr, clang::SC_None, nullptr);
    // Implicit means that this declaration was created by the compiler, and
    // not part of the actual source code.
    ParamDecls[I]->setImplicit();
  }
  FD->setParams(llvm::makeArrayRef(ParamDecls, kNumParams));
 
  // Implicit means that this declaration was created by the compiler, and
  // not part of the actual source code.
  FD->setImplicit();
 
  return FD;
}
 
// Create an expression like the following that references the rsForEachInternal to
// replace the callee in the original call expression that references rsForEach.
//
// ImplicitCastExpr 'void (*)(int, rs_script_call_t*, int, int, rs_allocation*)' <FunctionToPointerDecay>
// `-DeclRefExpr 'void' Function '_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation' 'void (int, rs_script_call_t*, int, int, rs_allocation*)'
clang::Expr* RSForEachLowering::CreateCalleeExprForInternalForEach() {
  clang::FunctionDecl* FDNew = CreateForEachInternalFunctionDecl();
 
  const clang::QualType FDNewType = FDNew->getType();
 
  clang::DeclRefExpr* refExpr = clang::DeclRefExpr::Create(
      mASTCtxt, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), FDNew,
      false, clang::SourceLocation(), FDNewType, clang::VK_RValue);
 
  clang::Expr* calleeNew = clang::ImplicitCastExpr::Create(
      mASTCtxt, mASTCtxt.getPointerType(FDNewType),
      clang::CK_FunctionToPointerDecay, refExpr, nullptr, clang::VK_RValue);
 
  return calleeNew;
}
 
// This visit method checks (via pattern matching) if the call expression is to
// rsForEach, and the arguments satisfy the restrictions on the
// rsForEach API. If so, replace the call with a rsForEachInternal call
// with the first argument replaced by the slot number of the kernel function
// referenced in the original first argument.
//
// See comments to the helper methods defined above for details.
void RSForEachLowering::VisitCallExpr(clang::CallExpr* CE) {
  int slot;
  bool hasOptions;
  const clang::FunctionDecl* kernel = matchKernelLaunchCall(CE, &slot, &hasOptions);
  if (kernel == nullptr) {
    return;
  }
 
  slangAssert(slot >= 0);
 
  const unsigned numArgsOrig = CE->getNumArgs();
 
  clang::QualType resultType = kernel->getReturnType().getCanonicalType();
  const unsigned numOutputsExpected = resultType->isVoidType() ? 0 : 1;
 
  const unsigned numInputsExpected = RSExportForEach::getNumInputs(mCtxt->getTargetAPI(), kernel);
 
  // Verifies that rsForEach takes the right number of input and output allocations.
  // TODO: Check input/output allocation types match kernel function expectation.
  const unsigned numAllocations = numArgsOrig - (hasOptions ? 2 : 1);
  if (numInputsExpected + numOutputsExpected != numAllocations) {
    mCtxt->ReportError(
      CE->getExprLoc(),
      "Number of input and output allocations unexpected for kernel function %0")
    << kernel->getName();
    return;
  }
 
  clang::Expr* calleeNew = CreateCalleeExprForInternalForEach();
  CE->setCallee(calleeNew);
 
  const clang::CanQualType IntTy = mASTCtxt.IntTy;
  const unsigned IntTySize = mASTCtxt.getTypeSize(IntTy);
  const llvm::APInt APIntSlot(IntTySize, slot);
  const clang::Expr* arg0 = CE->getArg(0);
  const clang::SourceLocation Loc(arg0->getLocStart());
  clang::Expr* IntSlotNum =
      clang::IntegerLiteral::Create(mASTCtxt, APIntSlot, IntTy, Loc);
  CE->setArg(0, IntSlotNum);
 
  /*
    The last few arguments to rsForEach or rsForEachWithOptions are allocations.
    Creates a new compound literal of an array initialized with those values, and
    passes it to rsForEachInternal as the last (the 5th) argument.
 
    For example, rsForEach(foo, ain1, ain2, aout) would be translated into
    rsForEachInternal(
        1,                                   // Slot number for kernel
        NULL,                                // Launch options
        2,                                   // Number of input allocations
        1,                                   // Number of output allocations
        (rs_allocation[]){ain1, ain2, aout)  // Input and output allocations
    );
 
    The AST for the rs_allocation array looks like following:
 
    ImplicitCastExpr 0x99575670 'struct rs_allocation *' <ArrayToPointerDecay>
    `-CompoundLiteralExpr 0x99575648 'struct rs_allocation [3]' lvalue
      `-InitListExpr 0x99575590 'struct rs_allocation [3]'
      |-ImplicitCastExpr 0x99574b38 'rs_allocation':'struct rs_allocation' <LValueToRValue>
      | `-DeclRefExpr 0x99574a08 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c408 'ain1' 'rs_allocation':'struct rs_allocation'
      |-ImplicitCastExpr 0x99574b50 'rs_allocation':'struct rs_allocation' <LValueToRValue>
      | `-DeclRefExpr 0x99574a30 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'ain2' 'rs_allocation':'struct rs_allocation'
      `-ImplicitCastExpr 0x99574b68 'rs_allocation':'struct rs_allocation' <LValueToRValue>
        `-DeclRefExpr 0x99574a58 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'aout' 'rs_allocation':'struct rs_allocation'
  */
 
  const clang::QualType& AllocTy = mCtxt->getAllocationType();
  const llvm::APInt APIntNumAllocs(IntTySize, numAllocations);
  clang::QualType AllocArrayTy = mASTCtxt.getConstantArrayType(
      AllocTy,
      APIntNumAllocs,
      clang::ArrayType::ArraySizeModifier::Normal,
      0  // index type qualifiers
  );
 
  const int allocArgIndexEnd = numArgsOrig - 1;
  int allocArgIndexStart = allocArgIndexEnd;
 
  clang::Expr** args = CE->getArgs();
 
  clang::SourceLocation lparenloc;
  clang::SourceLocation rparenloc;
 
  if (numAllocations > 0) {
    allocArgIndexStart = hasOptions ? 2 : 1;
    lparenloc = args[allocArgIndexStart]->getExprLoc();
    rparenloc = args[allocArgIndexEnd]->getExprLoc();
  }
 
  clang::InitListExpr* init = new (mASTCtxt) clang::InitListExpr(
      mASTCtxt,
      lparenloc,
      llvm::ArrayRef<clang::Expr*>(args + allocArgIndexStart, numAllocations),
      rparenloc);
  init->setType(AllocArrayTy);
 
  clang::TypeSourceInfo* ti = mASTCtxt.getTrivialTypeSourceInfo(AllocArrayTy);
  clang::CompoundLiteralExpr* CLE = new (mASTCtxt) clang::CompoundLiteralExpr(
      lparenloc,
      ti,
      AllocArrayTy,
      clang::VK_LValue,  // A compound literal is an l-value in C.
      init,
      false  // Not file scope
  );
 
  const clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy);
 
  clang::ImplicitCastExpr* Decay = clang::ImplicitCastExpr::Create(
      mASTCtxt,
      AllocPtrTy,
      clang::CK_ArrayToPointerDecay,
      CLE,
      nullptr,  // C++ cast path
      clang::VK_RValue
  );
 
  CE->setNumArgs(mASTCtxt, 5);
 
  CE->setArg(4, Decay);
 
  // Sets the new arguments for NULL launch option (if the user does not set one),
  // the number of outputs, and the number of inputs.
 
  if (!hasOptions) {
    const llvm::APInt APIntZero(IntTySize, 0);
    clang::Expr* IntNull =
        clang::IntegerLiteral::Create(mASTCtxt, APIntZero, IntTy, Loc);
    clang::QualType ScriptCallTy = mCtxt->getScriptCallType();
    const clang::QualType ScriptCallPtrTy = mASTCtxt.getPointerType(ScriptCallTy);
    clang::CStyleCastExpr* Cast =
        clang::CStyleCastExpr::Create(mASTCtxt,
                                      ScriptCallPtrTy,
                                      clang::VK_RValue,
                                      clang::CK_NullToPointer,
                                      IntNull,
                                      nullptr,
                                      mASTCtxt.getTrivialTypeSourceInfo(ScriptCallPtrTy),
                                      clang::SourceLocation(),
                                      clang::SourceLocation());
    CE->setArg(1, Cast);
  }
 
  const llvm::APInt APIntNumOutput(IntTySize, numOutputsExpected);
  clang::Expr* IntNumOutput =
      clang::IntegerLiteral::Create(mASTCtxt, APIntNumOutput, IntTy, Loc);
  CE->setArg(2, IntNumOutput);
 
  const llvm::APInt APIntNumInputs(IntTySize, numInputsExpected);
  clang::Expr* IntNumInputs =
      clang::IntegerLiteral::Create(mASTCtxt, APIntNumInputs, IntTy, Loc);
  CE->setArg(3, IntNumInputs);
}
 
void RSForEachLowering::VisitStmt(clang::Stmt* S) {
  for (clang::Stmt* Child : S->children()) {
    if (Child) {
      Visit(Child);
    }
  }
}
 
void RSForEachLowering::handleForEachCalls(clang::FunctionDecl* FD,
                                           unsigned int targetAPI) {
  slangAssert(FD && FD->hasBody());
 
  mInsideKernel = FD->hasAttr<clang::RenderScriptKernelAttr>();
  VisitStmt(FD->getBody());
}
 
}  // namespace slang