Continued pytorch fixes
This commit is contained in:
parent
70702fa37f
commit
a7d383d6ab
|
@ -169,6 +169,10 @@ void ValueCategory::store(mlir::OpBuilder &builder, ValueCategory toStore,
|
|||
mlir::Type elty;
|
||||
if (auto at = pt.getElementType().dyn_cast<LLVM::LLVMArrayType>()) {
|
||||
elty = at.getElementType();
|
||||
if (smt.getShape().back() != at.getNumElements()) {
|
||||
llvm::errs() << " at: " << at << " smt: " << smt << "\n";
|
||||
llvm::errs() << " val: " << val << " val.isRef: " << isReference << " ts: " << toStore.val << " ts.isRef: " << toStore.isReference << " isArray: " << isArray << "\n";
|
||||
}
|
||||
assert(smt.getShape().back() == at.getNumElements());
|
||||
} else {
|
||||
auto st = pt.getElementType().dyn_cast<LLVM::LLVMStructType>();
|
||||
|
|
|
@ -399,6 +399,14 @@ ValueCategory MLIRScanner::VisitConstantExpr(clang::ConstantExpr *expr) {
|
|||
return Visit(expr->getSubExpr());
|
||||
}
|
||||
|
||||
ValueCategory MLIRScanner::VisitTypeTraitExpr(clang::TypeTraitExpr *expr) {
|
||||
auto ty = getMLIRType(expr->getType()).cast<mlir::IntegerType>();
|
||||
return ValueCategory(
|
||||
builder.create<ConstantIntOp>(getMLIRLocation(expr->getExprLoc()),
|
||||
expr->getValue(), ty),
|
||||
/*isReference*/ false);
|
||||
}
|
||||
|
||||
ValueCategory MLIRScanner::VisitIntegerLiteral(clang::IntegerLiteral *expr) {
|
||||
auto ty = getMLIRType(expr->getType()).cast<mlir::IntegerType>();
|
||||
return ValueCategory(
|
||||
|
@ -735,7 +743,7 @@ ValueCategory MLIRScanner::VisitArrayInitLoop(clang::ArrayInitLoopExpr *expr,
|
|||
arrayinit.push_back(affineOp.getInductionVar());
|
||||
|
||||
auto alu = CommonArrayLookup(CommonArrayToPointer(tostore),
|
||||
affineOp.getInductionVar());
|
||||
affineOp.getInductionVar(), /*isImplicitRef*/false);
|
||||
|
||||
if (auto AILE = dyn_cast<ArrayInitLoopExpr>(expr->getSubExpr())) {
|
||||
VisitArrayInitLoop(AILE, alu);
|
||||
|
@ -1276,7 +1284,7 @@ ValueCategory MLIRScanner::CommonArrayToPointer(ValueCategory scalar) {
|
|||
}
|
||||
|
||||
ValueCategory MLIRScanner::CommonArrayLookup(ValueCategory array,
|
||||
mlir::Value idx) {
|
||||
mlir::Value idx, bool isImplicitRefResult) {
|
||||
mlir::Value val = array.getValue(builder);
|
||||
assert(val);
|
||||
|
||||
|
@ -1314,6 +1322,7 @@ ValueCategory MLIRScanner::CommonArrayLookup(ValueCategory array,
|
|||
auto mt = dref.val.getType().cast<MemRefType>();
|
||||
auto shape = std::vector<int64_t>(mt.getShape());
|
||||
if (shape.size() > 1) {
|
||||
// if (shape.size() > 2 || (shape.size() > 1 && !isImplicitRefResult)) {
|
||||
shape.erase(shape.begin());
|
||||
} else {
|
||||
shape[0] = -1;
|
||||
|
@ -1348,7 +1357,9 @@ MLIRScanner::VisitArraySubscriptExpr(clang::ArraySubscriptExpr *expr) {
|
|||
moo.val = builder.create<polygeist::SubIndexOp>(loc, mt0, moo.val,
|
||||
getConstantIndex(0));
|
||||
}
|
||||
return CommonArrayLookup(moo, idx);
|
||||
bool isArray = false;
|
||||
Glob.getMLIRType(expr->getType(), &isArray);
|
||||
return CommonArrayLookup(moo, idx, isArray);
|
||||
}
|
||||
|
||||
bool isRecursiveStruct(llvm::Type *T, llvm::Type *Meta,
|
||||
|
@ -1912,24 +1923,6 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
#endif
|
||||
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
|
||||
if (auto sr = dyn_cast<DeclRefExpr>(ic->getSubExpr())) {
|
||||
if (sr->getDecl()->getIdentifier() &&
|
||||
(sr->getDecl()->getName() == "fprintf" ||
|
||||
sr->getDecl()->getName() == "printf")) {
|
||||
|
||||
auto tocall = EmitCallee(expr->getCallee());
|
||||
auto fprintfF = Glob.GetOrCreateLLVMFunction(tocall);
|
||||
|
||||
std::vector<mlir::Value> args;
|
||||
size_t i = 0;
|
||||
for (auto a : expr->arguments()) {
|
||||
args.push_back(getLLVM(a));
|
||||
i++;
|
||||
}
|
||||
|
||||
builder.create<mlir::LLVM::CallOp>(loc, fprintfF, args);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
if (sr->getDecl()->getIdentifier() &&
|
||||
(sr->getDecl()->getName() == "__nv_fabsf" ||
|
||||
sr->getDecl()->getName() == "__nv_fabs" ||
|
||||
|
@ -2434,11 +2427,20 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
builder.setInsertionPoint(oldblock, oldpoint);
|
||||
|
||||
auto retTy = getMLIRType(expr->getType());
|
||||
if (sr->getDecl()->getName() == "__builtin_memcpy")
|
||||
if (sr->getDecl()->getName() == "__builtin_memcpy" || retTy.isa<LLVM::LLVMPointerType>()) {
|
||||
if (dst.getType().isa<MemRefType>())
|
||||
dst = builder.create<polygeist::Memref2PointerOp>(loc, retTy, dst);
|
||||
assert(dst.getType() == retTy);
|
||||
return ValueCategory(dst, /*isReference*/ false);
|
||||
else
|
||||
}
|
||||
else {
|
||||
if (!retTy.isa<mlir::IntegerType>()) {
|
||||
expr->dump();
|
||||
llvm::errs() << " retTy: " << retTy << "\n";
|
||||
}
|
||||
return ValueCategory(builder.create<ConstantIntOp>(loc, 0, retTy),
|
||||
/*isReference*/ false);
|
||||
}
|
||||
}
|
||||
}
|
||||
/*
|
||||
|
@ -2537,7 +2539,6 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
|
||||
std::set<std::string> funcs = {
|
||||
"strcmp",
|
||||
"sprintf",
|
||||
"fputs",
|
||||
"puts",
|
||||
"memcpy",
|
||||
|
@ -2545,6 +2546,7 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
"getenv",
|
||||
"strrchr",
|
||||
"mkdir",
|
||||
"printf",
|
||||
"fprintf",
|
||||
"sprintf",
|
||||
"fwrite",
|
||||
|
@ -4532,6 +4534,29 @@ ValueCategory MLIRScanner::VisitCastExpr(CastExpr *E) {
|
|||
}
|
||||
return ValueCategory(res, /*isReference*/ false);
|
||||
}
|
||||
case clang::CastKind::CK_FloatingToBoolean: {
|
||||
auto res = Visit(E->getSubExpr()).getValue(builder);
|
||||
auto prevTy = res.getType().cast<mlir::FloatType>();
|
||||
auto postTy = getMLIRType(E->getType()).cast<mlir::IntegerType>();
|
||||
bool signedType = true;
|
||||
if (auto bit = dyn_cast<clang::BuiltinType>(&*E->getType())) {
|
||||
if (bit->isUnsignedInteger())
|
||||
signedType = false;
|
||||
if (bit->isSignedInteger())
|
||||
signedType = true;
|
||||
}
|
||||
auto Zero = builder.create<ConstantFloatOp>(
|
||||
loc, APFloat::getZero(prevTy.getFloatSemantics()), prevTy);
|
||||
res = builder.create<arith::CmpFOp>(loc, CmpFPredicate::UNE, res, Zero);
|
||||
if (1 < postTy.getWidth()) {
|
||||
if (signedType) {
|
||||
res = builder.create<ExtSIOp>(loc, res, postTy);
|
||||
} else {
|
||||
res = builder.create<ExtUIOp>(loc, res, postTy);
|
||||
}
|
||||
}
|
||||
return ValueCategory(res, /*isReference*/ false);
|
||||
}
|
||||
case clang::CastKind::CK_IntegralToPointer: {
|
||||
auto res = Visit(E->getSubExpr()).getValue(builder);
|
||||
auto postTy = getMLIRType(E->getType()).cast<LLVM::LLVMPointerType>();
|
||||
|
@ -5097,7 +5122,7 @@ mlir::FuncOp MLIRASTConsumer::GetOrCreateMLIRFunction(const FunctionDecl *FD) {
|
|||
FunctionDecl::TemplatedKind::
|
||||
TK_DependentFunctionTemplateSpecialization);
|
||||
functionsToEmit.push_back(Def);
|
||||
} else if (FD->getIdentifier()) {
|
||||
} else {
|
||||
emitIfFound.insert(name);
|
||||
}
|
||||
assert(function->getParentOp() == module.get());
|
||||
|
|
|
@ -225,6 +225,8 @@ public:
|
|||
|
||||
ValueCategory VisitConstantExpr(clang::ConstantExpr *expr);
|
||||
|
||||
ValueCategory VisitTypeTraitExpr(clang::TypeTraitExpr *expr);
|
||||
|
||||
ValueCategory VisitIntegerLiteral(clang::IntegerLiteral *expr);
|
||||
|
||||
ValueCategory VisitCharacterLiteral(clang::CharacterLiteral *expr);
|
||||
|
@ -343,7 +345,7 @@ public:
|
|||
ValueCategory CommonFieldLookup(clang::QualType OT, const FieldDecl *FD,
|
||||
mlir::Value val, bool isLValue);
|
||||
|
||||
ValueCategory CommonArrayLookup(ValueCategory val, mlir::Value idx);
|
||||
ValueCategory CommonArrayLookup(ValueCategory val, mlir::Value idx, bool isImplicitRefResult);
|
||||
|
||||
ValueCategory CommonArrayToPointer(ValueCategory val);
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue