[flang][runtime] Interoperable POINTER deallocation validation (#96100)

Extend the runtime validation of deallocated pointers so that it also
works when pointers are allocated &/or deallocated outside Fortran.
Previously, bogus runtime errors would be reported for pointers
allocated via CFI_allocate() and deallocated in Fortran, and
CFI_deallocate() did not check that it was deallocating a whole
contiguous pointer that was allocated as such.
This commit is contained in:
Peter Klausler 2024-06-24 10:46:30 -07:00 committed by GitHub
parent eac925fb81
commit 514c1ec547
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 64 additions and 29 deletions

View File

@ -115,6 +115,11 @@ bool RTDECL(PointerIsAssociated)(const Descriptor &);
bool RTDECL(PointerIsAssociatedWith)(
const Descriptor &, const Descriptor *target);
// Fortran POINTERs are allocated with an extra validation word after their
// payloads in order to detect erroneous deallocations later.
RT_API_ATTRS void *AllocateValidatedPointerPayload(std::size_t);
RT_API_ATTRS bool ValidatePointerPayload(const ISO::CFI_cdesc_t &);
} // extern "C"
} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_POINTER_H_

View File

@ -13,6 +13,7 @@
#include "terminator.h"
#include "flang/ISO_Fortran_binding_wrapper.h"
#include "flang/Runtime/descriptor.h"
#include "flang/Runtime/pointer.h"
#include "flang/Runtime/type-code.h"
#include <cstdlib>
@ -75,7 +76,7 @@ RT_API_ATTRS int CFI_allocate(CFI_cdesc_t *descriptor,
dim->sm = byteSize;
byteSize *= extent;
}
void *p{byteSize ? std::malloc(byteSize) : std::malloc(1)};
void *p{runtime::AllocateValidatedPointerPayload(byteSize)};
if (!p && byteSize) {
return CFI_ERROR_MEM_ALLOCATION;
}
@ -91,8 +92,11 @@ RT_API_ATTRS int CFI_deallocate(CFI_cdesc_t *descriptor) {
if (descriptor->version != CFI_VERSION) {
return CFI_INVALID_DESCRIPTOR;
}
if (descriptor->attribute != CFI_attribute_allocatable &&
descriptor->attribute != CFI_attribute_pointer) {
if (descriptor->attribute == CFI_attribute_pointer) {
if (!runtime::ValidatePointerPayload(*descriptor)) {
return CFI_INVALID_DESCRIPTOR;
}
} else if (descriptor->attribute != CFI_attribute_allocatable) {
// Non-interoperable object
return CFI_INVALID_DESCRIPTOR;
}

View File

@ -199,7 +199,16 @@ RT_API_ATTRS int Descriptor::Destroy(
}
}
RT_API_ATTRS int Descriptor::Deallocate() { return ISO::CFI_deallocate(&raw_); }
RT_API_ATTRS int Descriptor::Deallocate() {
ISO::CFI_cdesc_t &descriptor{raw()};
if (!descriptor.base_addr) {
return CFI_ERROR_BASE_ADDR_NULL;
} else {
std::free(descriptor.base_addr);
descriptor.base_addr = nullptr;
return CFI_SUCCESS;
}
}
RT_API_ATTRS bool Descriptor::DecrementSubscripts(
SubscriptValue *subscript, const int *permutation) const {

View File

@ -124,6 +124,23 @@ void RTDEF(PointerAssociateRemapping)(Descriptor &pointer,
}
}
RT_API_ATTRS void *AllocateValidatedPointerPayload(std::size_t byteSize) {
// Add space for a footer to validate during deallocation.
constexpr std::size_t align{sizeof(std::uintptr_t)};
byteSize = ((byteSize / align) + 1) * align;
std::size_t total{byteSize + sizeof(std::uintptr_t)};
void *p{std::malloc(total)};
if (p) {
// Fill the footer word with the XOR of the ones' complement of
// the base address, which is a value that would be highly unlikely
// to appear accidentally at the right spot.
std::uintptr_t *footer{
reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
*footer = ~reinterpret_cast<std::uintptr_t>(p);
}
return p;
}
int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat,
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
@ -137,22 +154,12 @@ int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat,
elementBytes = pointer.raw().elem_len = 0;
}
std::size_t byteSize{pointer.Elements() * elementBytes};
// Add space for a footer to validate during DEALLOCATE.
constexpr std::size_t align{sizeof(std::uintptr_t)};
byteSize = ((byteSize + align - 1) / align) * align;
std::size_t total{byteSize + sizeof(std::uintptr_t)};
void *p{std::malloc(total)};
void *p{AllocateValidatedPointerPayload(byteSize)};
if (!p) {
return ReturnError(terminator, CFI_ERROR_MEM_ALLOCATION, errMsg, hasStat);
}
pointer.set_base_addr(p);
pointer.SetByteStrides();
// Fill the footer word with the XOR of the ones' complement of
// the base address, which is a value that would be highly unlikely
// to appear accidentally at the right spot.
std::uintptr_t *footer{
reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
*footer = ~reinterpret_cast<std::uintptr_t>(p);
int stat{StatOk};
if (const DescriptorAddendum * addendum{pointer.Addendum()}) {
if (const auto *derived{addendum->derivedType()}) {
@ -176,6 +183,27 @@ int RTDEF(PointerAllocateSource)(Descriptor &pointer, const Descriptor &source,
return stat;
}
static RT_API_ATTRS std::size_t GetByteSize(
const ISO::CFI_cdesc_t &descriptor) {
std::size_t rank{descriptor.rank};
const ISO::CFI_dim_t *dim{descriptor.dim};
std::size_t byteSize{descriptor.elem_len};
for (std::size_t j{0}; j < rank; ++j) {
byteSize *= dim[j].extent;
}
return byteSize;
}
bool RT_API_ATTRS ValidatePointerPayload(const ISO::CFI_cdesc_t &desc) {
std::size_t byteSize{GetByteSize(desc)};
constexpr std::size_t align{sizeof(std::uintptr_t)};
byteSize = ((byteSize / align) + 1) * align;
const void *p{desc.base_addr};
const std::uintptr_t *footer{reinterpret_cast<const std::uintptr_t *>(
static_cast<const char *>(p) + byteSize)};
return *footer == ~reinterpret_cast<std::uintptr_t>(p);
}
int RTDEF(PointerDeallocate)(Descriptor &pointer, bool hasStat,
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
@ -185,20 +213,9 @@ int RTDEF(PointerDeallocate)(Descriptor &pointer, bool hasStat,
if (!pointer.IsAllocated()) {
return ReturnError(terminator, StatBaseNull, errMsg, hasStat);
}
if (executionEnvironment.checkPointerDeallocation) {
// Validate the footer. This should fail if the pointer doesn't
// span the entire object, or the object was not allocated as a
// pointer.
std::size_t byteSize{pointer.Elements() * pointer.ElementBytes()};
constexpr std::size_t align{sizeof(std::uintptr_t)};
byteSize = ((byteSize + align - 1) / align) * align;
void *p{pointer.raw().base_addr};
std::uintptr_t *footer{
reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
if (*footer != ~reinterpret_cast<std::uintptr_t>(p)) {
return ReturnError(
terminator, StatBadPointerDeallocation, errMsg, hasStat);
}
if (executionEnvironment.checkPointerDeallocation &&
!ValidatePointerPayload(pointer.raw())) {
return ReturnError(terminator, StatBadPointerDeallocation, errMsg, hasStat);
}
return ReturnError(terminator,
pointer.Destroy(/*finalize=*/true, /*destroyPointers=*/true, &terminator),