瀏覽代碼

LibPDF: Make SampledFunction::evaluate() work for n-dimensional input

I didn't find example code for this and the AI assistant did very
poorly on this as well. So I had to write it all by myself!

It can be much more efficient I think, but I think the overall
shape is maybe roughly fine.
Nico Weber 1 年之前
父節點
當前提交
f4a847894f
共有 2 個文件被更改,包括 92 次插入23 次删除
  1. 23 0
      Tests/LibPDF/TestPDF.cpp
  2. 69 23
      Userland/Libraries/LibPDF/Function.cpp

+ 23 - 0
Tests/LibPDF/TestPDF.cpp

@@ -137,6 +137,29 @@ TEST_CASE(sampled)
     EXPECT_EQ(MUST(f2->evaluate(Vector<float> { 0.5f })), (Vector<float> { 10.0f, 0.0f }));
     EXPECT_EQ(MUST(f2->evaluate(Vector<float> { 0.75f })), (Vector<float> { 5.0f, 4.0f }));
     EXPECT_EQ(MUST(f2->evaluate(Vector<float> { 1.0f })), (Vector<float> { 0.0f, 8.0f }));
+
+    auto f3 = MUST(make_sampled_function(Vector<u8> { { 0, 255, 0, 255, 0, 255 } }, { 0.0f, 1.0f, 0.0f, 1.0f }, { 0.0f, 10.0f }, { 3, 2 }));
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 0.0f, 0.0f })), Vector<float> { 0.0f });
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 0.25f, 0.0f })), Vector<float> { 5.0f });
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 0.5f, 0.0f })), Vector<float> { 10.0f });
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 0.75f, 0.0f })), Vector<float> { 5.0f });
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 1.0f, 0.0f })), Vector<float> { 0.0f });
+
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 0.0f, 0.5f })), Vector<float> { 5.0f });
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 0.25f, 0.5f })), Vector<float> { 5.0f });
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 0.5f, 0.5f })), Vector<float> { 5.0f });
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 0.75f, 0.5f })), Vector<float> { 5.0f });
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 1.0f, 0.5f })), Vector<float> { 5.0f });
+
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 0.0f, 1.0f })), Vector<float> { 10.0f });
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 0.25f, 1.0f })), Vector<float> { 5.0f });
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 0.5f, 1.0f })), Vector<float> { 0.0f });
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 0.75f, 1.0f })), Vector<float> { 5.0f });
+    EXPECT_EQ(MUST(f3->evaluate(Vector<float> { 1.0f, 1.0f })), Vector<float> { 10.0f });
+
+    auto f4 = MUST(make_sampled_function(Vector<u8> { { 0, 255, 255, 0, 0, 255, 255, 0 } }, { 0.0f, 1.0f, 0.0f, 1.0f }, { 0.0f, 10.0f, 0.0f, 8.0f }, { 2, 2 }));
+    EXPECT_EQ(MUST(f4->evaluate(Vector<float> { 0.0f, 0.0f })), (Vector<float> { 0.0f, 8.0f }));
+    EXPECT_EQ(MUST(f4->evaluate(Vector<float> { 0.5f, 0.5f })), (Vector<float> { 5.0f, 4.0f }));
 }
 
 static PDF::PDFErrorOr<NonnullRefPtr<PDF::Function>> make_postscript_function(StringView program, Vector<float> domain, Vector<float> range)

+ 69 - 23
Userland/Libraries/LibPDF/Function.cpp

@@ -28,6 +28,17 @@ public:
 private:
     SampledFunction(NonnullRefPtr<StreamObject>);
 
+    float sample(Vector<int> const& coordinates, size_t r) const
+    {
+        size_t stride = 1;
+        size_t offset = 0;
+        for (size_t i = 0; i < coordinates.size(); ++i) {
+            offset += coordinates[i] * stride;
+            stride *= m_sizes[i];
+        }
+        return m_sample_data[offset * m_range.size() + r];
+    }
+
     Vector<Bound> m_domain;
     Vector<Bound> m_range;
 
@@ -46,6 +57,7 @@ private:
     NonnullRefPtr<StreamObject> m_stream;
     ReadonlyBytes m_sample_data;
 
+    Vector<float> mutable m_inputs;
     Vector<float> mutable m_outputs;
 };
 
@@ -144,21 +156,19 @@ SampledFunction::create(Document* document, Vector<Bound> domain, Optional<Vecto
     function->m_order = order;
     function->m_encode = move(encode);
     function->m_decode = move(decode);
+    function->m_inputs.resize(function->m_domain.size());
     function->m_outputs.resize(function->m_range.size());
     return function;
 }
 
-PDFErrorOr<ReadonlySpan<float>> SampledFunction::evaluate(ReadonlySpan<float> x) const
+PDFErrorOr<ReadonlySpan<float>> SampledFunction::evaluate(ReadonlySpan<float> xs) const
 {
-    if (x.size() != m_domain.size())
+    if (xs.size() != m_domain.size())
         return Error { Error::Type::MalformedPDF, "Function argument size does not match domain size" };
 
     if (m_order != Order::Linear)
         return Error { Error::Type::RenderingUnsupported, "Sample function with cubic order not yet implemented" };
 
-    if (m_domain.size() != 1)
-        return Error { Error::Type::RenderingUnsupported, "Sample function with m > 1 not yet implemented" };
-
     if (m_bits_per_sample != 8)
         return Error { Error::Type::RenderingUnsupported, "Sample function with bits per sample != 8 not yet implemented" };
 
@@ -166,25 +176,61 @@ PDFErrorOr<ReadonlySpan<float>> SampledFunction::evaluate(ReadonlySpan<float> x)
         return y_min + (x - x_min) * (y_max - y_min) / (x_max - x_min);
     };
 
-    float xc = clamp(x[0], m_domain[0].lower, m_domain[0].upper);
-    float e = interpolate(xc, m_domain[0].lower, m_domain[0].upper, m_encode[0].lower, m_encode[0].upper);
-    float ec = clamp(e, 0.0f, static_cast<float>(m_sizes[0] - 1));
-
-    float e0 = floor(ec);
-    float e1 = ceil(ec);
-    if (e0 == e1) {
-        if (e0 == 0.0f)
-            e1 = 1.0f;
-        else
-            e0 = e1 - 1.0f;
+    for (size_t i = 0; i < m_domain.size(); ++i) {
+        float x = clamp(xs[i], m_domain[i].lower, m_domain[i].upper);
+        float e = interpolate(x, m_domain[i].lower, m_domain[i].upper, m_encode[i].lower, m_encode[i].upper);
+        float ec = clamp(e, 0.0f, static_cast<float>(m_sizes[i] - 1));
+        m_inputs[i] = ec;
     }
-    size_t plane_size = m_range.size();
-    for (size_t i = 0; i < m_range.size(); ++i) {
-        float s0 = m_sample_data[(size_t)e0 * plane_size + i];
-        float s1 = m_sample_data[(size_t)e1 * plane_size + i];
-        float r0 = interpolate(ec, e0, e1, s0, s1);
-        r0 = interpolate(r0, 0.0f, 255.0f, m_decode[i].lower, m_decode[i].upper);
-        m_outputs[i] = clamp(r0, m_range[i].lower, m_range[i].upper);
+
+    auto get_bounds = [](float x, float& low, float& high) {
+        low = floorf(x);
+        high = ceilf(x);
+        if (low == high) {
+            if (low == 0.0f)
+                high = 1.0f;
+            else
+                low = high - 1.0f;
+        }
+    };
+
+    for (size_t r = 0; r < m_range.size(); ++r) {
+        // For 1-D input data, we need to sample 2 points, one to the left and one to the right, and then interpolate between them.
+        // For 2-D input data, we need to sample 4 points (top-left, top-right, bottom-left, bottom-right),
+        // then reduce them to 2 points by interpolating along y, and then to 1 by interpolating along x.
+        // For 3-D input data, it's 8 points in a cube around the point, then reduce to 4 points by interpolating along z,
+        // then 2 by interpolating along y, then 1 by interpolating along x.
+        // So for the general case, we create 2**N samples, and then for each coordinate, we cut the number of samples in half
+        // by interpolating along that coordinate.
+        Vector<float> samples;
+        samples.resize(1 << m_domain.size());
+        // The i'th bit of mask indicates if the i'th coordinate is rounded up or down.
+        Vector<int> coordinates;
+        coordinates.resize(m_domain.size());
+        for (size_t mask = 0; mask < (1u << m_domain.size()); ++mask) {
+            for (size_t i = 0; i < m_domain.size(); ++i) {
+                float ec = m_inputs[i];
+                float e0, e1;
+                get_bounds(ec, e0, e1);
+                if ((mask & (1u << i)) != 0)
+                    ec = e1;
+                else
+                    ec = e0;
+                coordinates[i] = static_cast<int>(ec);
+            }
+            samples[mask] = sample(coordinates, r);
+        }
+
+        for (int i = static_cast<int>(m_domain.size() - 1); i >= 0; --i) {
+            float ec = m_inputs[i];
+            float e0, e1;
+            get_bounds(ec, e0, e1);
+            for (size_t mask = 0; mask < (1u << i); ++mask)
+                samples[mask] = interpolate(ec, e0, e1, samples[mask], samples[mask | (1 << i)]);
+        }
+
+        float result = interpolate(samples[0], 0.0f, 255.0f, m_decode[r].lower, m_decode[r].upper);
+        m_outputs[r] = clamp(result, m_range[r].lower, m_range[r].upper);
     }
 
     return m_outputs;