Prechádzať zdrojové kódy

LibC: Templatize unique filename enumeration for mkstemp() et al

This allows us to implement mkstemp() with open() directly, instead of
first lstat()'ing, and then open()'ing the filename.

Also implement tmpfile() in terms of mkstemp() instead of mktemp().
Andreas Kling 4 rokov pred
rodič
commit
b0f19c2af4

+ 1 - 6
Userland/Libraries/LibC/stdio.cpp

@@ -1205,16 +1205,11 @@ void funlockfile([[maybe_unused]] FILE* filehandle)
 FILE* tmpfile()
 {
     char tmp_path[] = "/tmp/XXXXXX";
-    if (__generate_unique_filename(tmp_path) < 0)
-        return nullptr;
-
-    int fd = open(tmp_path, O_CREAT | O_EXCL | O_RDWR, S_IWUSR | S_IRUSR);
+    int fd = mkstemp(tmp_path);
     if (fd < 0)
         return nullptr;
-
     // FIXME: instead of using this hack, implement with O_TMPFILE or similar
     unlink(tmp_path);
-
     return fdopen(fd, "rw");
 }
 }

+ 37 - 24
Userland/Libraries/LibC/stdlib.cpp

@@ -171,14 +171,13 @@ static bool is_either(char* str, int offset, char lower, char upper)
     return ch == lower || ch == upper;
 }
 
-__attribute__((warn_unused_result)) int __generate_unique_filename(char* pattern)
+template<typename Callback>
+inline int generate_unique_filename(char* pattern, Callback callback)
 {
     size_t length = strlen(pattern);
 
-    if (length < 6 || memcmp(pattern + length - 6, "XXXXXX", 6)) {
-        errno = EINVAL;
-        return -1;
-    }
+    if (length < 6 || memcmp(pattern + length - 6, "XXXXXX", 6))
+        return EINVAL;
 
     size_t start = length - 6;
 
@@ -187,13 +186,11 @@ __attribute__((warn_unused_result)) int __generate_unique_filename(char* pattern
     for (int attempt = 0; attempt < 100; ++attempt) {
         for (int i = 0; i < 6; ++i)
             pattern[start + i] = random_characters[(arc4random() % (sizeof(random_characters) - 1))];
-        struct stat st;
-        int rc = lstat(pattern, &st);
-        if (rc < 0 && errno == ENOENT)
+        if (callback() == IterationDecision::Break)
             return 0;
     }
-    errno = EEXIST;
-    return -1;
+
+    return EEXIST;
 }
 
 extern "C" {
@@ -727,31 +724,47 @@ int system(const char* command)
 
 char* mktemp(char* pattern)
 {
-    if (__generate_unique_filename(pattern) < 0)
+    auto error = generate_unique_filename(pattern, [&] {
+        struct stat st;
+        int rc = lstat(pattern, &st);
+        if (rc < 0 && errno == ENOENT)
+            return IterationDecision::Break;
+        return IterationDecision::Continue;
+    });
+    if (error) {
         pattern[0] = '\0';
-
+        errno = error;
+    }
     return pattern;
 }
 
 int mkstemp(char* pattern)
 {
-    char* path = mktemp(pattern);
-
-    int fd = open(path, O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR); // I'm using the flags I saw glibc using.
-    if (fd >= 0)
-        return fd;
-
-    return -1;
+    int fd = -1;
+    auto error = generate_unique_filename(pattern, [&] {
+        fd = open(pattern, O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR); // I'm using the flags I saw glibc using.
+        if (fd >= 0)
+            return IterationDecision::Break;
+        return IterationDecision::Continue;
+    });
+    if (error) {
+        errno = error;
+        return -1;
+    }
+    return fd;
 }
 
 char* mkdtemp(char* pattern)
 {
-    if (__generate_unique_filename(pattern) < 0)
-        return nullptr;
-
-    if (mkdir(pattern, 0700) < 0)
+    auto error = generate_unique_filename(pattern, [&] {
+        if (mkdir(pattern, 0700) == 0)
+            return IterationDecision::Break;
+        return IterationDecision::Continue;
+    });
+    if (error) {
+        errno = error;
         return nullptr;
-
+    }
     return pattern;
 }