Continued pytorch fixes

This commit is contained in:
William S. Moses 2021-12-05 23:14:55 -05:00 committed by William Moses
parent 70702fa37f
commit a7d383d6ab
3 changed files with 57 additions and 26 deletions

View File

@ -169,6 +169,10 @@ void ValueCategory::store(mlir::OpBuilder &builder, ValueCategory toStore,
mlir::Type elty;
if (auto at = pt.getElementType().dyn_cast<LLVM::LLVMArrayType>()) {
elty = at.getElementType();
if (smt.getShape().back() != at.getNumElements()) {
llvm::errs() << " at: " << at << " smt: " << smt << "\n";
llvm::errs() << " val: " << val << " val.isRef: " << isReference << " ts: " << toStore.val << " ts.isRef: " << toStore.isReference << " isArray: " << isArray << "\n";
}
assert(smt.getShape().back() == at.getNumElements());
} else {
auto st = pt.getElementType().dyn_cast<LLVM::LLVMStructType>();

View File

@ -399,6 +399,14 @@ ValueCategory MLIRScanner::VisitConstantExpr(clang::ConstantExpr *expr) {
return Visit(expr->getSubExpr());
}
ValueCategory MLIRScanner::VisitTypeTraitExpr(clang::TypeTraitExpr *expr) {
auto ty = getMLIRType(expr->getType()).cast<mlir::IntegerType>();
return ValueCategory(
builder.create<ConstantIntOp>(getMLIRLocation(expr->getExprLoc()),
expr->getValue(), ty),
/*isReference*/ false);
}
ValueCategory MLIRScanner::VisitIntegerLiteral(clang::IntegerLiteral *expr) {
auto ty = getMLIRType(expr->getType()).cast<mlir::IntegerType>();
return ValueCategory(
@ -735,7 +743,7 @@ ValueCategory MLIRScanner::VisitArrayInitLoop(clang::ArrayInitLoopExpr *expr,
arrayinit.push_back(affineOp.getInductionVar());
auto alu = CommonArrayLookup(CommonArrayToPointer(tostore),
affineOp.getInductionVar());
affineOp.getInductionVar(), /*isImplicitRef*/false);
if (auto AILE = dyn_cast<ArrayInitLoopExpr>(expr->getSubExpr())) {
VisitArrayInitLoop(AILE, alu);
@ -1276,7 +1284,7 @@ ValueCategory MLIRScanner::CommonArrayToPointer(ValueCategory scalar) {
}
ValueCategory MLIRScanner::CommonArrayLookup(ValueCategory array,
mlir::Value idx) {
mlir::Value idx, bool isImplicitRefResult) {
mlir::Value val = array.getValue(builder);
assert(val);
@ -1314,6 +1322,7 @@ ValueCategory MLIRScanner::CommonArrayLookup(ValueCategory array,
auto mt = dref.val.getType().cast<MemRefType>();
auto shape = std::vector<int64_t>(mt.getShape());
if (shape.size() > 1) {
// if (shape.size() > 2 || (shape.size() > 1 && !isImplicitRefResult)) {
shape.erase(shape.begin());
} else {
shape[0] = -1;
@ -1348,7 +1357,9 @@ MLIRScanner::VisitArraySubscriptExpr(clang::ArraySubscriptExpr *expr) {
moo.val = builder.create<polygeist::SubIndexOp>(loc, mt0, moo.val,
getConstantIndex(0));
}
return CommonArrayLookup(moo, idx);
bool isArray = false;
Glob.getMLIRType(expr->getType(), &isArray);
return CommonArrayLookup(moo, idx, isArray);
}
bool isRecursiveStruct(llvm::Type *T, llvm::Type *Meta,
@ -1912,24 +1923,6 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
#endif
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
if (auto sr = dyn_cast<DeclRefExpr>(ic->getSubExpr())) {
if (sr->getDecl()->getIdentifier() &&
(sr->getDecl()->getName() == "fprintf" ||
sr->getDecl()->getName() == "printf")) {
auto tocall = EmitCallee(expr->getCallee());
auto fprintfF = Glob.GetOrCreateLLVMFunction(tocall);
std::vector<mlir::Value> args;
size_t i = 0;
for (auto a : expr->arguments()) {
args.push_back(getLLVM(a));
i++;
}
builder.create<mlir::LLVM::CallOp>(loc, fprintfF, args);
return nullptr;
}
if (sr->getDecl()->getIdentifier() &&
(sr->getDecl()->getName() == "__nv_fabsf" ||
sr->getDecl()->getName() == "__nv_fabs" ||
@ -2434,11 +2427,20 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
builder.setInsertionPoint(oldblock, oldpoint);
auto retTy = getMLIRType(expr->getType());
if (sr->getDecl()->getName() == "__builtin_memcpy")
if (sr->getDecl()->getName() == "__builtin_memcpy" || retTy.isa<LLVM::LLVMPointerType>()) {
if (dst.getType().isa<MemRefType>())
dst = builder.create<polygeist::Memref2PointerOp>(loc, retTy, dst);
assert(dst.getType() == retTy);
return ValueCategory(dst, /*isReference*/ false);
else
}
else {
if (!retTy.isa<mlir::IntegerType>()) {
expr->dump();
llvm::errs() << " retTy: " << retTy << "\n";
}
return ValueCategory(builder.create<ConstantIntOp>(loc, 0, retTy),
/*isReference*/ false);
}
}
}
/*
@ -2537,7 +2539,6 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
std::set<std::string> funcs = {
"strcmp",
"sprintf",
"fputs",
"puts",
"memcpy",
@ -2545,6 +2546,7 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
"getenv",
"strrchr",
"mkdir",
"printf",
"fprintf",
"sprintf",
"fwrite",
@ -4532,6 +4534,29 @@ ValueCategory MLIRScanner::VisitCastExpr(CastExpr *E) {
}
return ValueCategory(res, /*isReference*/ false);
}
case clang::CastKind::CK_FloatingToBoolean: {
auto res = Visit(E->getSubExpr()).getValue(builder);
auto prevTy = res.getType().cast<mlir::FloatType>();
auto postTy = getMLIRType(E->getType()).cast<mlir::IntegerType>();
bool signedType = true;
if (auto bit = dyn_cast<clang::BuiltinType>(&*E->getType())) {
if (bit->isUnsignedInteger())
signedType = false;
if (bit->isSignedInteger())
signedType = true;
}
auto Zero = builder.create<ConstantFloatOp>(
loc, APFloat::getZero(prevTy.getFloatSemantics()), prevTy);
res = builder.create<arith::CmpFOp>(loc, CmpFPredicate::UNE, res, Zero);
if (1 < postTy.getWidth()) {
if (signedType) {
res = builder.create<ExtSIOp>(loc, res, postTy);
} else {
res = builder.create<ExtUIOp>(loc, res, postTy);
}
}
return ValueCategory(res, /*isReference*/ false);
}
case clang::CastKind::CK_IntegralToPointer: {
auto res = Visit(E->getSubExpr()).getValue(builder);
auto postTy = getMLIRType(E->getType()).cast<LLVM::LLVMPointerType>();
@ -5097,7 +5122,7 @@ mlir::FuncOp MLIRASTConsumer::GetOrCreateMLIRFunction(const FunctionDecl *FD) {
FunctionDecl::TemplatedKind::
TK_DependentFunctionTemplateSpecialization);
functionsToEmit.push_back(Def);
} else if (FD->getIdentifier()) {
} else {
emitIfFound.insert(name);
}
assert(function->getParentOp() == module.get());

View File

@ -225,6 +225,8 @@ public:
ValueCategory VisitConstantExpr(clang::ConstantExpr *expr);
ValueCategory VisitTypeTraitExpr(clang::TypeTraitExpr *expr);
ValueCategory VisitIntegerLiteral(clang::IntegerLiteral *expr);
ValueCategory VisitCharacterLiteral(clang::CharacterLiteral *expr);
@ -343,7 +345,7 @@ public:
ValueCategory CommonFieldLookup(clang::QualType OT, const FieldDecl *FD,
mlir::Value val, bool isLValue);
ValueCategory CommonArrayLookup(ValueCategory val, mlir::Value idx);
ValueCategory CommonArrayLookup(ValueCategory val, mlir::Value idx, bool isImplicitRefResult);
ValueCategory CommonArrayToPointer(ValueCategory val);
};