feat: generate vertices in kernel shader

This commit is contained in:
dscyrescotti
2024-05-12 23:12:07 +07:00
parent de9deb7faf
commit 6356b88c9a
14 changed files with 186 additions and 97 deletions

View File

@@ -10,6 +10,8 @@
EC3565522BEFC65F00A4E0BF /* NSManagedObjectContext++.swift in Sources */ = {isa = PBXBuildFile; fileRef = EC3565512BEFC65F00A4E0BF /* NSManagedObjectContext++.swift */; };
EC3565542BEFC6AD00A4E0BF /* View++.swift in Sources */ = {isa = PBXBuildFile; fileRef = EC3565532BEFC6AD00A4E0BF /* View++.swift */; };
EC3565562BEFC7B300A4E0BF /* NSManagedObject++.swift in Sources */ = {isa = PBXBuildFile; fileRef = EC3565552BEFC7B300A4E0BF /* NSManagedObject++.swift */; };
EC35655A2BF060D900A4E0BF /* Quad.metal in Sources */ = {isa = PBXBuildFile; fileRef = EC3565592BF060D900A4E0BF /* Quad.metal */; };
EC35655C2BF0712A00A4E0BF /* Float++.swift in Sources */ = {isa = PBXBuildFile; fileRef = EC35655B2BF0712A00A4E0BF /* Float++.swift */; };
EC4538892BEBCAE000A86FEC /* Quad.swift in Sources */ = {isa = PBXBuildFile; fileRef = EC4538882BEBCAE000A86FEC /* Quad.swift */; };
EC7F6BEC2BE5E6E300A34A7B /* MemolaApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = EC7F6BEB2BE5E6E300A34A7B /* MemolaApp.swift */; };
EC7F6BF02BE5E6E400A34A7B /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = EC7F6BEF2BE5E6E400A34A7B /* Assets.xcassets */; };
@@ -80,6 +82,8 @@
EC3565512BEFC65F00A4E0BF /* NSManagedObjectContext++.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "NSManagedObjectContext++.swift"; sourceTree = "<group>"; };
EC3565532BEFC6AD00A4E0BF /* View++.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "View++.swift"; sourceTree = "<group>"; };
EC3565552BEFC7B300A4E0BF /* NSManagedObject++.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "NSManagedObject++.swift"; sourceTree = "<group>"; };
EC3565592BF060D900A4E0BF /* Quad.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Quad.metal; sourceTree = "<group>"; };
EC35655B2BF0712A00A4E0BF /* Float++.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "Float++.swift"; sourceTree = "<group>"; };
EC4538882BEBCAE000A86FEC /* Quad.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Quad.swift; sourceTree = "<group>"; };
EC7F6BE82BE5E6E300A34A7B /* Memola.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = Memola.app; sourceTree = BUILT_PRODUCTS_DIR; };
EC7F6BEB2BE5E6E300A34A7B /* MemolaApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MemolaApp.swift; sourceTree = "<group>"; };
@@ -300,6 +304,7 @@
ECA738922BE6011100A4542E /* Stroke.metal */,
ECA738942BE6012D00A4542E /* ViewPort.metal */,
ECA738962BE6014200A4542E /* Graphic.metal */,
EC3565592BF060D900A4E0BF /* Quad.metal */,
);
path = Shaders;
sourceTree = "<group>";
@@ -348,6 +353,7 @@
EC3565512BEFC65F00A4E0BF /* NSManagedObjectContext++.swift */,
EC3565532BEFC6AD00A4E0BF /* View++.swift */,
EC3565552BEFC7B300A4E0BF /* NSManagedObject++.swift */,
EC35655B2BF0712A00A4E0BF /* Float++.swift */,
);
path = Extensions;
sourceTree = "<group>";
@@ -594,6 +600,7 @@
ECA738E42BE6110800A4542E /* Drawable.swift in Sources */,
ECA738AD2BE60CC600A4542E /* DrawingView.swift in Sources */,
ECA738E02BE610B900A4542E /* EraserRenderPass.swift in Sources */,
EC35655A2BF060D900A4E0BF /* Quad.metal in Sources */,
ECA738912BE600F500A4542E /* Cache.metal in Sources */,
ECA7389C2BE601AF00A4542E /* GridVertex.swift in Sources */,
ECA738A82BE6025900A4542E /* GraphicUniforms.swift in Sources */,
@@ -631,6 +638,7 @@
ECFA15242BEF223300455818 /* GraphicContextObject.swift in Sources */,
EC3565562BEFC7B300A4E0BF /* NSManagedObject++.swift in Sources */,
ECA738EC2BE6124E00A4542E /* CGAffineTransform++.swift in Sources */,
EC35655C2BF0712A00A4E0BF /* Float++.swift in Sources */,
ECA738E22BE610D000A4542E /* GraphicRenderPass.swift in Sources */,
ECA738DC2BE6108D00A4542E /* StrokeRenderPass.swift in Sources */,
ECA738F42BE612A000A4542E /* Array++.swift in Sources */,

View File

@@ -60,12 +60,15 @@ final class GraphicContext: @unchecked Sendable {
}
extension GraphicContext {
func load() {
func loadStrokes() {
guard let object else { return }
self.strokes = object.strokes.compactMap { stroke -> Stroke? in
guard let stroke = stroke as? StrokeObject else { return nil }
let _stroke = Stroke(object: stroke)
_stroke.loadVertices()
_stroke.loadQuads()
withPersistence(\.backgroundContext) { [stroke] context in
context.refresh(stroke, mergeChanges: false)
}
return _stroke
}
}
@@ -129,7 +132,6 @@ extension GraphicContext {
currentStroke.saveQuads(for: quads)
try context.saveIfNeeded()
if let stroke = currentStroke.object {
currentStroke.quads.removeAll()
context.refresh(stroke, mergeChanges: false)
}
}

View File

@@ -53,7 +53,7 @@ extension Canvas {
}
let graphicContext = canvas.graphicContext
self?.graphicContext.object = graphicContext
self?.graphicContext.load()
self?.graphicContext.loadStrokes()
context.refresh(canvas, mergeChanges: false)
DispatchQueue.main.async { [weak self] in
self?.state = .loaded

View File

@@ -131,4 +131,13 @@ struct PipelineStates {
}
return try? device.makeComputePipelineState(function: function)
}
static func createQuadPipelineState(from renderer: Renderer) -> MTLComputePipelineState? {
let device = renderer.device
let library = renderer.library
guard let function = library.makeFunction(name: "generate_stroke_vertices") else {
return nil
}
return try? device.makeComputePipelineState(function: function)
}
}

View File

@@ -6,71 +6,51 @@
//
import CoreData
import MetalKit
import Foundation
struct Quad {
var originX: CGFloat
var originY: CGFloat
var size: CGFloat
var rotation: CGFloat
var originX: Float
var originY: Float
var size: Float
var rotation: Float
var shape: Int16
var color: vector_float4
init(object: QuadObject) {
self.originX = object.originX
self.originY = object.originY
self.size = object.size
self.rotation = object.rotation
self.originX = object.originX.float
self.originY = object.originY.float
self.size = object.size.float
self.rotation = object.rotation.float
self.shape = object.shape
self.color = [
object.color[0].float,
object.color[1].float,
object.color[2].float,
object.color[3].float
]
}
init(origin: CGPoint, size: CGFloat, rotation: CGFloat, shape: Int16) {
self.originX = origin.x
self.originY = origin.y
self.size = size
self.rotation = rotation
init(origin: CGPoint, size: CGFloat, rotation: CGFloat, shape: Int16, color: [CGFloat]) {
self.originX = origin.x.float
self.originY = origin.y.float
self.size = size.float
self.rotation = rotation.float
self.shape = shape
}
var origin: CGPoint {
get { CGPoint(x: originX, y: originY) }
set {
originX = newValue.x
originY = newValue.y
}
}
func generateVertices(_ color: [CGFloat]) -> [QuadVertex] {
guard let shape = QuadShape.init(rawValue: shape) else { return [] }
switch shape {
case .rounded:
return generateRoundedQuad(color)
case .squared:
return generateSquaredQuad(color)
}
}
func generateRoundedQuad(_ color: [CGFloat]) -> [QuadVertex] {
let halfSize = size * 0.5
return [
QuadVertex(x: origin.x - halfSize, y: origin.y - halfSize, textCoord: CGPoint(x: 0, y: 0), color: color, origin: origin, rotation: rotation),
QuadVertex(x: origin.x + halfSize, y: origin.y - halfSize, textCoord: CGPoint(x: 1, y: 0), color: color, origin: origin, rotation: rotation),
QuadVertex(x: origin.x - halfSize, y: origin.y + halfSize, textCoord: CGPoint(x: 0, y: 1), color: color, origin: origin, rotation: rotation),
QuadVertex(x: origin.x + halfSize, y: origin.y - halfSize, textCoord: CGPoint(x: 1, y: 0), color: color, origin: origin, rotation: rotation),
QuadVertex(x: origin.x - halfSize, y: origin.y + halfSize, textCoord: CGPoint(x: 0, y: 1), color: color, origin: origin, rotation: rotation),
QuadVertex(x: origin.x + halfSize, y: origin.y + halfSize, textCoord: CGPoint(x: 1, y: 1), color: color, origin: origin, rotation: rotation)
]
}
func generateSquaredQuad(_ color: [CGFloat]) -> [QuadVertex] {
let vHalfSize = size * 0.5
let hHalfSize = size * 0.15
return [
QuadVertex(x: origin.x - hHalfSize, y: origin.y - vHalfSize, textCoord: CGPoint(x: 0, y: 0), color: color, origin: origin, rotation: rotation),
QuadVertex(x: origin.x + hHalfSize, y: origin.y - vHalfSize, textCoord: CGPoint(x: 1, y: 0), color: color, origin: origin, rotation: rotation),
QuadVertex(x: origin.x - hHalfSize, y: origin.y + vHalfSize, textCoord: CGPoint(x: 0, y: 1), color: color, origin: origin, rotation: rotation),
QuadVertex(x: origin.x + hHalfSize, y: origin.y - vHalfSize, textCoord: CGPoint(x: 1, y: 0), color: color, origin: origin, rotation: rotation),
QuadVertex(x: origin.x - hHalfSize, y: origin.y + vHalfSize, textCoord: CGPoint(x: 0, y: 1), color: color, origin: origin, rotation: rotation),
QuadVertex(x: origin.x + hHalfSize, y: origin.y + vHalfSize, textCoord: CGPoint(x: 1, y: 1), color: color, origin: origin, rotation: rotation)
]
self.color = [color[0].float, color[1].float, color[2].float, color[3].float]
}
}
extension Quad {
var origin: CGPoint {
get { CGPoint(x: originX.cgFloat, y: originY.cgFloat) }
set {
originX = newValue.x.float
originY = newValue.y.float
}
}
func getColor() -> [CGFloat] {
[color.x.cgFloat, color.y.cgFloat, color.z.cgFloat, color.w.cgFloat]
}
}

View File

@@ -27,7 +27,7 @@ struct SolidPointStrokeGenerator: StrokeGenerator {
let control = CGPoint.middle(p1: start, p2: end)
addCurve(from: start, to: end, by: control, on: stroke)
case 3:
discardVertices(upto: stroke.vertexIndex, quadIndex: stroke.quadIndex, on: stroke)
stroke.removeQuads(from: stroke.quadIndex + 1)
let index = stroke.keyPoints.count - 1
var start = stroke.keyPoints[index - 2]
var end = CGPoint.middle(p1: stroke.keyPoints[index - 2], p2: stroke.keyPoints[index - 1])
@@ -62,7 +62,7 @@ struct SolidPointStrokeGenerator: StrokeGenerator {
}
private func smoothOutPath(on stroke: Stroke) {
discardVertices(upto: stroke.vertexIndex, quadIndex: stroke.quadIndex, on: stroke)
stroke.removeQuads(from: stroke.quadIndex + 1)
adjustPreviousKeyPoint(on: stroke)
switch stroke.keyPoints.count {
case 4:
@@ -80,7 +80,6 @@ struct SolidPointStrokeGenerator: StrokeGenerator {
addCurve(from: start, to: end, by: control, on: stroke)
}
stroke.quadIndex = stroke.quads.count - 1
stroke.vertexIndex = stroke.vertices.endIndex - 1
}
private func adjustPreviousKeyPoint(on stroke: Stroke) {
@@ -107,8 +106,7 @@ struct SolidPointStrokeGenerator: StrokeGenerator {
rotation = CGFloat.random(in: 0...360) * .pi / 180
}
let quad = stroke.addQuad(at: point, rotation: rotation, shape: .rounded)
stroke.vertices.append(contentsOf: quad.generateVertices(stroke.color))
stroke.vertexCount = stroke.vertices.endIndex
stroke.quads.append(quad)
}
private func addCurve(from start: CGPoint, to end: CGPoint, by control: CGPoint, on stroke: Stroke) {
@@ -131,17 +129,6 @@ struct SolidPointStrokeGenerator: StrokeGenerator {
addPoint(point, on: stroke)
}
}
private func discardVertices(upto index: Int, quadIndex: Int, on stroke: Stroke) {
if index < 0 {
stroke.vertices.removeAll()
} else {
let count = stroke.vertices.endIndex
let dropCount = count - (max(0, index) + 1)
stroke.vertices.removeLast(dropCount)
}
stroke.removeQuads(from: quadIndex + 1)
}
}
extension SolidPointStrokeGenerator {

View File

@@ -41,27 +41,21 @@ final class Stroke: @unchecked Sendable {
}
var angle: CGFloat = 0
var penStyle: Style {
Style(rawValue: style) ?? .marker
}
var batchIndex: Int = 0
var quadIndex: Int = -1
var vertexIndex: Int = -1
var keyPoints: [CGPoint] = []
var thicknessFactor: CGFloat = 0.7
var vertices: [QuadVertex] = []
var vertexBuffer: MTLBuffer?
var vertexCount: Int = 0
var texture: MTLTexture?
var isEmpty: Bool {
vertices.isEmpty
quads.isEmpty
}
var isEraserPenStyle: Bool {
penStyle == .eraser
}
@@ -78,14 +72,15 @@ final class Stroke: @unchecked Sendable {
penStyle.anyPenStyle.generator.finish(at: point, on: self)
keyPoints.removeAll()
}
}
func loadVertices() {
extension Stroke {
func loadQuads() {
guard let object else { return }
for quad in object.quads {
guard let quad = quad as? QuadObject else { continue }
vertices.append(contentsOf: Quad(object: quad).generateVertices(object.color))
quads = object.quads.compactMap { quad in
guard let quad = quad as? QuadObject else { return nil }
return Quad(object: quad)
}
vertexCount = vertices.endIndex
}
func addQuad(at point: CGPoint, rotation: CGFloat, shape: QuadShape) -> Quad {
@@ -93,7 +88,8 @@ final class Stroke: @unchecked Sendable {
origin: point,
size: thickness,
rotation: rotation,
shape: shape.rawValue
shape: shape.rawValue,
color: color
)
quads.append(quad)
return quad
@@ -112,11 +108,12 @@ final class Stroke: @unchecked Sendable {
func saveQuads(for quads: [Quad]) {
for _quad in quads {
let quad = QuadObject(\.backgroundContext)
quad.originX = _quad.originX
quad.originY = _quad.originY
quad.size = _quad.size
quad.rotation = _quad.rotation
quad.originX = _quad.originX.cgFloat
quad.originY = _quad.originY.cgFloat
quad.size = _quad.size.cgFloat
quad.rotation = _quad.rotation.cgFloat
quad.shape = _quad.shape
quad.color = _quad.getColor()
quad.stroke = object
object?.quads.add(quad)
}
@@ -128,7 +125,6 @@ extension Stroke: Drawable {
if texture == nil {
texture = penStyle.anyPenStyle.loadTexture(on: device)
}
vertexBuffer = device.makeBuffer(bytes: &vertices, length: MemoryLayout<QuadVertex>.stride * vertexCount, options: .cpuCacheModeWriteCombined)
}
func draw(device: MTLDevice, renderEncoder: MTLRenderCommandEncoder) {
@@ -136,7 +132,8 @@ extension Stroke: Drawable {
prepare(device: device)
renderEncoder.setFragmentTexture(texture, index: 0)
renderEncoder.setVertexBuffer(vertexBuffer, offset: 0, index: 0)
renderEncoder.drawPrimitives(type: .triangle, vertexStart: 0, vertexCount: vertexCount)
renderEncoder.drawPrimitives(type: .triangle, vertexStart: 0, vertexCount: quads.endIndex * 6)
vertexBuffer = nil
}
}

View File

@@ -15,6 +15,7 @@ class StrokeRenderPass: RenderPass {
weak var graphicDescriptor: MTLRenderPassDescriptor?
var strokePipelineState: MTLRenderPipelineState?
var quadPipelineState: MTLComputePipelineState?
weak var graphicPipelineState: MTLRenderPipelineState?
var stroke: Stroke?
@@ -23,6 +24,7 @@ class StrokeRenderPass: RenderPass {
init(renderer: Renderer) {
descriptor = MTLRenderPassDescriptor()
strokePipelineState = PipelineStates.createStrokePipelineState(from: renderer)
quadPipelineState = PipelineStates.createQuadPipelineState(from: renderer)
}
func resize(on view: MTKView, to size: CGSize, with renderer: Renderer) {
@@ -33,6 +35,8 @@ class StrokeRenderPass: RenderPass {
func draw(on canvas: Canvas, with renderer: Renderer) {
guard let descriptor else { return }
generateVertexBuffer(on: canvas, with: renderer)
guard let strokeTexture else { return }
descriptor.colorAttachments[0].texture = strokeTexture
descriptor.colorAttachments[0].clearColor = MTLClearColor(red: 1, green: 1, blue: 1, alpha: 0)
@@ -50,14 +54,38 @@ class StrokeRenderPass: RenderPass {
canvas.setUniformsBuffer(device: renderer.device, renderEncoder: renderEncoder)
stroke?.draw(device: renderer.device, renderEncoder: renderEncoder)
renderEncoder.endEncoding()
commandBuffer.commit()
drawStrokeTexture(on: canvas, with: renderer)
}
func drawStrokeTexture(on canvas: Canvas, with renderer: Renderer) {
private func generateVertexBuffer(on canvas: Canvas, with renderer: Renderer) {
guard let stroke, !stroke.quads.isEmpty, let quadPipelineState else { return }
guard let quadCommandBuffer = renderer.commandQueue.makeCommandBuffer() else { return }
guard let computeEncoder = quadCommandBuffer.makeComputeCommandEncoder() else { return }
computeEncoder.label = "Quad Render Pass"
let quadCount = stroke.quads.endIndex
var quads = stroke.quads
let quadBuffer = renderer.device.makeBuffer(bytes: &quads, length: MemoryLayout<Quad>.stride * quadCount, options: [])
let vertexBuffer = renderer.device.makeBuffer(length: MemoryLayout<QuadVertex>.stride * quadCount * 6, options: [])
computeEncoder.setComputePipelineState(quadPipelineState)
computeEncoder.setBuffer(quadBuffer, offset: 0, index: 0)
computeEncoder.setBuffer(vertexBuffer, offset: 0, index: 1)
stroke.vertexBuffer = vertexBuffer
let threadsPerGroup = MTLSize(width: 1, height: 1, depth: 1)
let numThreadgroups = MTLSize(width: quadCount + 1, height: 1, depth: 1)
computeEncoder.dispatchThreadgroups(numThreadgroups, threadsPerThreadgroup: threadsPerGroup)
computeEncoder.endEncoding()
quadCommandBuffer.commit()
}
private func drawStrokeTexture(on canvas: Canvas, with renderer: Renderer) {
guard let stroke else { return }
guard let graphicDescriptor, let graphicPipelineState else { return }

View File

@@ -0,0 +1,54 @@
//
// Quad.metal
// Memola
//
// Created by Dscyre Scotti on 5/12/24.
//
#include <metal_stdlib>
using namespace metal;
struct Quad {
float originX;
float originY;
float size;
float rotation;
int shape;
float4 color;
};
struct Vertex {
float4 position;
float2 textCoord;
float4 color;
float2 origin;
float rotation;
};
Vertex createVertex(Quad quad, float2 factor, float2 textCoord) {
Vertex output;
float x = quad.originX + factor.x;
float y = quad.originY + factor.y;
output.position = float4(x, y, 0, 1);
output.textCoord = textCoord;
output.color = quad.color;
output.origin = float2(quad.originX, quad.originY);
output.rotation = quad.rotation;
return output;
}
kernel void generate_stroke_vertices(
device Quad *quads [[buffer(0)]],
device Vertex *vertices [[buffer(1)]],
uint gid [[thread_position_in_grid]]
) {
uint index = gid * 6;
Quad quad = quads[gid];
float halfSize = quad.size * 0.5;
vertices[index] = createVertex(quad, float2(-halfSize, -halfSize), float2(0, 0));
vertices[index + 1] = createVertex(quad, float2(halfSize, -halfSize), float2(1, 0));
vertices[index + 2] = createVertex(quad, float2(-halfSize, halfSize), float2(0, 1));
vertices[index + 3] = createVertex(quad, float2(halfSize, -halfSize), float2(1, 0));
vertices[index + 4] = createVertex(quad, float2(-halfSize, halfSize), float2(0, 1));
vertices[index + 5] = createVertex(quad, float2(halfSize, halfSize), float2(1, 1));
}

View File

@@ -62,6 +62,9 @@ class CanvasViewController: UIViewController {
override func viewDidDisappear(_ animated: Bool) {
super.viewDidDisappear(animated)
history.resetRedo()
withPersistence(\.backgroundContext) { context in
context.refreshAllObjects()
}
}
}

View File

@@ -0,0 +1,14 @@
//
// Float++.swift
// Memola
//
// Created by Dscyre Scotti on 5/12/24.
//
import Foundation
extension Float {
var cgFloat: CGFloat {
CGFloat(self)
}
}

View File

@@ -31,6 +31,11 @@ struct MemosView: View {
}
.fullScreenCover(item: $memo) { memo in
MemoView(memo: memo)
.onDisappear {
withPersistence(\.viewContext) { context in
context.refreshAllObjects()
}
}
}
}

View File

@@ -15,5 +15,6 @@ final class QuadObject: NSManagedObject {
@NSManaged var size: CGFloat
@NSManaged var rotation: CGFloat
@NSManaged var shape: Int16
@NSManaged var color: [CGFloat]
@NSManaged var stroke: StrokeObject?
}

View File

@@ -17,6 +17,7 @@
<relationship name="canvas" maxCount="1" deletionRule="Cascade" destinationEntity="CanvasObject" inverseName="memo" inverseEntity="CanvasObject"/>
</entity>
<entity name="QuadObject" representedClassName="QuadObject" syncable="YES">
<attribute name="color" attributeType="Transformable" valueTransformerName="NSSecureUnarchiveFromDataTransformer" customClassName="[CGFloat]"/>
<attribute name="originX" attributeType="Double" defaultValueString="0.0" usesScalarValueType="YES"/>
<attribute name="originY" attributeType="Double" defaultValueString="0.0" usesScalarValueType="YES"/>
<attribute name="rotation" attributeType="Double" defaultValueString="0.0" usesScalarValueType="YES"/>