[clang][Interp] Implement complex division (#94892)

Share the implementation with the current interpreter.
This commit is contained in:
Timm Baeder 2024-06-18 13:49:02 +02:00 committed by GitHub
parent 69753aa43b
commit 4d7d45e8ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 196 additions and 42 deletions

View File

@ -62,5 +62,8 @@ GCCTypeClass EvaluateBuiltinClassifyType(QualType T,
void HandleComplexComplexMul(llvm::APFloat A, llvm::APFloat B, llvm::APFloat C,
llvm::APFloat D, llvm::APFloat &ResR,
llvm::APFloat &ResI);
void HandleComplexComplexDiv(llvm::APFloat A, llvm::APFloat B, llvm::APFloat C,
llvm::APFloat D, llvm::APFloat &ResR,
llvm::APFloat &ResI);
#endif

View File

@ -15189,6 +15189,48 @@ void HandleComplexComplexMul(APFloat A, APFloat B, APFloat C, APFloat D,
}
}
void HandleComplexComplexDiv(APFloat A, APFloat B, APFloat C, APFloat D,
APFloat &ResR, APFloat &ResI) {
// This is an implementation of complex division according to the
// constraints laid out in C11 Annex G. The implementation uses the
// following naming scheme:
// (a + ib) / (c + id)
int DenomLogB = 0;
APFloat MaxCD = maxnum(abs(C), abs(D));
if (MaxCD.isFinite()) {
DenomLogB = ilogb(MaxCD);
C = scalbn(C, -DenomLogB, APFloat::rmNearestTiesToEven);
D = scalbn(D, -DenomLogB, APFloat::rmNearestTiesToEven);
}
APFloat Denom = C * C + D * D;
ResR =
scalbn((A * C + B * D) / Denom, -DenomLogB, APFloat::rmNearestTiesToEven);
ResI =
scalbn((B * C - A * D) / Denom, -DenomLogB, APFloat::rmNearestTiesToEven);
if (ResR.isNaN() && ResI.isNaN()) {
if (Denom.isPosZero() && (!A.isNaN() || !B.isNaN())) {
ResR = APFloat::getInf(ResR.getSemantics(), C.isNegative()) * A;
ResI = APFloat::getInf(ResR.getSemantics(), C.isNegative()) * B;
} else if ((A.isInfinity() || B.isInfinity()) && C.isFinite() &&
D.isFinite()) {
A = APFloat::copySign(APFloat(A.getSemantics(), A.isInfinity() ? 1 : 0),
A);
B = APFloat::copySign(APFloat(B.getSemantics(), B.isInfinity() ? 1 : 0),
B);
ResR = APFloat::getInf(ResR.getSemantics()) * (A * C + B * D);
ResI = APFloat::getInf(ResI.getSemantics()) * (B * C - A * D);
} else if (MaxCD.isInfinity() && A.isFinite() && B.isFinite()) {
C = APFloat::copySign(APFloat(C.getSemantics(), C.isInfinity() ? 1 : 0),
C);
D = APFloat::copySign(APFloat(D.getSemantics(), D.isInfinity() ? 1 : 0),
D);
ResR = APFloat::getZero(ResR.getSemantics()) * (A * C + B * D);
ResI = APFloat::getZero(ResI.getSemantics()) * (B * C - A * D);
}
}
}
bool ComplexExprEvaluator::VisitBinaryOperator(const BinaryOperator *E) {
if (E->isPtrMemOp() || E->isAssignmentOp() || E->getOpcode() == BO_Comma)
return ExprEvaluatorBaseTy::VisitBinaryOperator(E);
@ -15326,39 +15368,7 @@ bool ComplexExprEvaluator::VisitBinaryOperator(const BinaryOperator *E) {
// No real optimizations we can do here, stub out with zero.
B = APFloat::getZero(A.getSemantics());
}
int DenomLogB = 0;
APFloat MaxCD = maxnum(abs(C), abs(D));
if (MaxCD.isFinite()) {
DenomLogB = ilogb(MaxCD);
C = scalbn(C, -DenomLogB, APFloat::rmNearestTiesToEven);
D = scalbn(D, -DenomLogB, APFloat::rmNearestTiesToEven);
}
APFloat Denom = C * C + D * D;
ResR = scalbn((A * C + B * D) / Denom, -DenomLogB,
APFloat::rmNearestTiesToEven);
ResI = scalbn((B * C - A * D) / Denom, -DenomLogB,
APFloat::rmNearestTiesToEven);
if (ResR.isNaN() && ResI.isNaN()) {
if (Denom.isPosZero() && (!A.isNaN() || !B.isNaN())) {
ResR = APFloat::getInf(ResR.getSemantics(), C.isNegative()) * A;
ResI = APFloat::getInf(ResR.getSemantics(), C.isNegative()) * B;
} else if ((A.isInfinity() || B.isInfinity()) && C.isFinite() &&
D.isFinite()) {
A = APFloat::copySign(
APFloat(A.getSemantics(), A.isInfinity() ? 1 : 0), A);
B = APFloat::copySign(
APFloat(B.getSemantics(), B.isInfinity() ? 1 : 0), B);
ResR = APFloat::getInf(ResR.getSemantics()) * (A * C + B * D);
ResI = APFloat::getInf(ResI.getSemantics()) * (B * C - A * D);
} else if (MaxCD.isInfinity() && A.isFinite() && B.isFinite()) {
C = APFloat::copySign(
APFloat(C.getSemantics(), C.isInfinity() ? 1 : 0), C);
D = APFloat::copySign(
APFloat(D.getSemantics(), D.isInfinity() ? 1 : 0), D);
ResR = APFloat::getZero(ResR.getSemantics()) * (A * C + B * D);
ResI = APFloat::getZero(ResI.getSemantics()) * (B * C - A * D);
}
}
HandleComplexComplexDiv(A, B, C, D, ResR, ResI);
}
} else {
if (RHS.getComplexIntReal() == 0 && RHS.getComplexIntImag() == 0)

View File

@ -891,11 +891,14 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
if (const auto *AT = RHSType->getAs<AtomicType>())
RHSType = AT->getValueType();
bool LHSIsComplex = LHSType->isAnyComplexType();
unsigned LHSOffset;
bool RHSIsComplex = RHSType->isAnyComplexType();
// For ComplexComplex Mul, we have special ops to make their implementation
// easier.
BinaryOperatorKind Op = E->getOpcode();
if (Op == BO_Mul && LHSType->isAnyComplexType() &&
RHSType->isAnyComplexType()) {
if (Op == BO_Mul && LHSIsComplex && RHSIsComplex) {
assert(classifyPrim(LHSType->getAs<ComplexType>()->getElementType()) ==
classifyPrim(RHSType->getAs<ComplexType>()->getElementType()));
PrimType ElemT =
@ -907,18 +910,51 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
return this->emitMulc(ElemT, E);
}
if (Op == BO_Div && RHSIsComplex) {
QualType ElemQT = RHSType->getAs<ComplexType>()->getElementType();
PrimType ElemT = classifyPrim(ElemQT);
// If the LHS is not complex, we still need to do the full complex
// division, so just stub create a complex value and stub it out with
// the LHS and a zero.
if (!LHSIsComplex) {
// This is using the RHS type for the fake-complex LHS.
if (auto LHSO = allocateLocal(RHS))
LHSOffset = *LHSO;
else
return false;
if (!this->emitGetPtrLocal(LHSOffset, E))
return false;
if (!this->visit(LHS))
return false;
// real is LHS
if (!this->emitInitElem(ElemT, 0, E))
return false;
// imag is zero
if (!this->visitZeroInitializer(ElemT, ElemQT, E))
return false;
if (!this->emitInitElem(ElemT, 1, E))
return false;
} else {
if (!this->visit(LHS))
return false;
}
if (!this->visit(RHS))
return false;
return this->emitDivc(ElemT, E);
}
// Evaluate LHS and save value to LHSOffset.
bool LHSIsComplex;
unsigned LHSOffset;
if (LHSType->isAnyComplexType()) {
LHSIsComplex = true;
LHSOffset = this->allocateLocalPrimitive(LHS, PT_Ptr, true, false);
if (!this->visit(LHS))
return false;
if (!this->emitSetLocal(PT_Ptr, LHSOffset, E))
return false;
} else {
LHSIsComplex = false;
PrimType LHST = classifyPrim(LHSType);
LHSOffset = this->allocateLocalPrimitive(LHS, LHST, true, false);
if (!this->visit(LHS))
@ -928,17 +964,14 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
}
// Same with RHS.
bool RHSIsComplex;
unsigned RHSOffset;
if (RHSType->isAnyComplexType()) {
RHSIsComplex = true;
RHSOffset = this->allocateLocalPrimitive(RHS, PT_Ptr, true, false);
if (!this->visit(RHS))
return false;
if (!this->emitSetLocal(PT_Ptr, RHSOffset, E))
return false;
} else {
RHSIsComplex = false;
PrimType RHST = classifyPrim(RHSType);
RHSOffset = this->allocateLocalPrimitive(RHS, RHST, true, false);
if (!this->visit(RHS))
@ -1018,6 +1051,22 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
return false;
}
break;
case BO_Div:
assert(!RHSIsComplex);
if (!loadComplexValue(LHSIsComplex, false, ElemIndex, LHSOffset, LHS))
return false;
if (!loadComplexValue(RHSIsComplex, false, ElemIndex, RHSOffset, RHS))
return false;
if (ResultElemT == PT_Float) {
if (!this->emitDivf(getRoundingMode(E), E))
return false;
} else {
if (!this->emitDiv(ResultElemT, E))
return false;
}
break;
default:
return false;

View File

@ -425,6 +425,78 @@ inline bool Mulc(InterpState &S, CodePtr OpPC) {
return true;
}
template <PrimType Name, class T = typename PrimConv<Name>::T>
inline bool Divc(InterpState &S, CodePtr OpPC) {
const Pointer &RHS = S.Stk.pop<Pointer>();
const Pointer &LHS = S.Stk.pop<Pointer>();
const Pointer &Result = S.Stk.peek<Pointer>();
if constexpr (std::is_same_v<T, Floating>) {
APFloat A = LHS.atIndex(0).deref<Floating>().getAPFloat();
APFloat B = LHS.atIndex(1).deref<Floating>().getAPFloat();
APFloat C = RHS.atIndex(0).deref<Floating>().getAPFloat();
APFloat D = RHS.atIndex(1).deref<Floating>().getAPFloat();
APFloat ResR(A.getSemantics());
APFloat ResI(A.getSemantics());
HandleComplexComplexDiv(A, B, C, D, ResR, ResI);
// Copy into the result.
Result.atIndex(0).deref<Floating>() = Floating(ResR);
Result.atIndex(0).initialize();
Result.atIndex(1).deref<Floating>() = Floating(ResI);
Result.atIndex(1).initialize();
Result.initialize();
} else {
// Integer element type.
const T &LHSR = LHS.atIndex(0).deref<T>();
const T &LHSI = LHS.atIndex(1).deref<T>();
const T &RHSR = RHS.atIndex(0).deref<T>();
const T &RHSI = RHS.atIndex(1).deref<T>();
unsigned Bits = LHSR.bitWidth();
const T Zero = T::from(0, Bits);
if (Compare(RHSR, Zero) == ComparisonCategoryResult::Equal &&
Compare(RHSI, Zero) == ComparisonCategoryResult::Equal) {
const SourceInfo &E = S.Current->getSource(OpPC);
S.FFDiag(E, diag::note_expr_divide_by_zero);
return false;
}
// Den = real(RHS)² + imag(RHS)²
T A, B;
if (T::mul(RHSR, RHSR, Bits, &A) || T::mul(RHSI, RHSI, Bits, &B))
return false;
T Den;
if (T::add(A, B, Bits, &Den))
return false;
// real(Result) = ((real(LHS) * real(RHS)) + (imag(LHS) * imag(RHS))) / Den
T &ResultR = Result.atIndex(0).deref<T>();
T &ResultI = Result.atIndex(1).deref<T>();
if (T::mul(LHSR, RHSR, Bits, &A) || T::mul(LHSI, RHSI, Bits, &B))
return false;
if (T::add(A, B, Bits, &ResultR))
return false;
if (T::div(ResultR, Den, Bits, &ResultR))
return false;
Result.atIndex(0).initialize();
// imag(Result) = ((imag(LHS) * real(RHS)) - (real(LHS) * imag(RHS))) / Den
if (T::mul(LHSI, RHSR, Bits, &A) || T::mul(LHSR, RHSI, Bits, &B))
return false;
if (T::sub(A, B, Bits, &ResultI))
return false;
if (T::div(ResultI, Den, Bits, &ResultI))
return false;
Result.atIndex(1).initialize();
Result.initialize();
}
return true;
}
/// 1) Pops the RHS from the stack.
/// 2) Pops the LHS from the stack.
/// 3) Pushes 'LHS & RHS' on the stack

View File

@ -533,6 +533,10 @@ def Mulc : Opcode {
def Rem : IntegerOpcode;
def Div : IntegerOpcode;
def Divf : FloatOpcode;
def Divc : Opcode {
let Types = [NumberTypeClass];
let HasGroup = 1;
}
def BitAnd : IntegerOpcode;
def BitOr : IntegerOpcode;

View File

@ -40,6 +40,21 @@ constexpr _Complex int IIMC = IIMA * IIMB;
static_assert(__real(IIMC) == -30, "");
static_assert(__imag(IIMC) == 40, "");
static_assert(1.0j / 0.0 == 1); // both-error {{static assertion}} \
// both-note {{division by zero}}
static_assert(__builtin_isinf_sign(__real__((1.0 + 1.0j) / (0.0 + 0.0j))) == 1);
static_assert(__builtin_isinf_sign(__real__((1.0 + 1.0j) / 0.0)) == 1); // both-error {{static assertion}} \
// both-note {{division by zero}}
static_assert(__builtin_isinf_sign(__real__((__builtin_inf() + 1.0j) / (0.0 + 0.0j))) == 1);
static_assert(__builtin_isinf_sign(__imag__((1.0 + InfC) / (0.0 + 0.0j))) == 1);
static_assert(__builtin_isinf_sign(__imag__((InfInf) / (0.0 + 0.0j))) == 1);
constexpr _Complex int IIDA = {10,20};
constexpr _Complex int IIDB = {1,2};
constexpr _Complex int IIDC = IIDA / IIDB;
static_assert(__real(IIDC) == 10, "");
static_assert(__imag(IIDC) == 0, "");
constexpr _Complex int Comma1 = {1, 2};
constexpr _Complex int Comma2 = (0, Comma1);
static_assert(Comma1 == Comma1, "");

View File

@ -1,4 +1,5 @@
// RUN: %clang_cc1 %s -std=c++1z -fsyntax-only -verify
// RUN: %clang_cc1 %s -std=c++1z -fsyntax-only -verify -fexperimental-new-constant-interpreter
//
// Test the constant folding of builtin complex numbers.