    #
    #
    # This Blender add-on assigns one or more Bezier Curves as shape keys to another
    # Bezier Curve
    #
    # Supported Blender Version: 2.80 Beta
    #
    # Copyright (C) 2019  Shrinivas Kulkarni
    #
    
    # License: GPL-3.0 (https://github.com/Shriinivas/assignshapekey/blob/master/LICENSE)
    
    #
    
    import bpy, bmesh, bgl, gpu
    from gpu_extras.batch import batch_for_shader
    
    from bpy.props import BoolProperty, EnumProperty, StringProperty
    
    from collections import OrderedDict
    from mathutils import Vector
    from math import sqrt, floor
    from functools import cmp_to_key
    
    from bpy.types import Panel, Operator, AddonPreferences
    
    
    
    #Add-on metadata (the bl_info dictionary of this Blender add-on).
    #The stray run of viewer line numbers that had been pasted into the middle
    #of the literal is removed so that the dictionary parses again.
    bl_info = {
        "name": "Assign Shape Keys",
        "author": "Shrinivas Kulkarni",
        "version": (1, 0, 0),
        "location": "View 3D > Sidebar > Create Tab",
        "description": "Assigns one or more Bezier curves as shape keys to another Bezier curve",
        "category": "Add Curve",
        "wiki_url": "https://github.com/Shriinivas/assignshapekey/blob/master/README.md",
        "blender": (2, 80, 0),
    }
    
    #Criteria by which parts (splines) of different curves can be matched.
    #Each entry is an (identifier, label, description) triple; the identifiers
    #are the keys interpreted by alignPath.
    matchList = [
        ('vCnt', 'Vertex Count', 'Match by vertex count'),
        ('bbArea', 'Area', 'Match by surface area of the bounding box'),
        ('bbHeight', 'Height', 'Match by bounding box height'),
        ('bbWidth', 'Width', 'Match by bounding box width'),
        ('bbDepth', 'Depth', 'Match by bounding box depth'),
        ('minX', 'Min X', 'Match by bounding box Min X'),
        ('maxX', 'Max X', 'Match by bounding box Max X'),
        ('minY', 'Min Y', 'Match by bounding box Min Y'),
        ('maxY', 'Max Y', 'Match by bounding box Max Y'),
        ('minZ', 'Min Z', 'Match by bounding box Min Z'),
        ('maxZ', 'Max Z', 'Match by bounding box Max Z'),
    ]
    
    #Default tolerance: used as the default margin of the float/vector
    #comparison helpers and as the default error bound of
    #Segment.getSegLenRecurs
    DEF_ERR_MARGIN = 0.0001
    
    def isBezier(obj):
        """Return True if obj is a curve object whose first spline is a Bezier spline.

        Only the first spline is inspected; a curve without splines is not
        considered a Bezier curve.
        """
        if(obj.type != 'CURVE'):
            return False

        splines = obj.data.splines
        if(len(splines) == 0):
            return False

        return splines[0].type == 'BEZIER'
    
    #Avoid errors due to floating point conversions/comparisons
    #TODO: return -1, 0, 1
    def floatCmpWithMargin(float1, float2, margin = DEF_ERR_MARGIN):
        """Return True when the two values differ by strictly less than margin."""
        delta = float1 - float2
        return -margin < delta < margin
    
    def vectCmpWithMargin(v1, v2, margin = DEF_ERR_MARGIN):
        """Return True when every component of v1 is within margin of v2's.

        Components are compared index by index over the length of v1.
        """
        for idx in range(len(v1)):
            if(not floatCmpWithMargin(v1[idx], v2[idx], margin)):
                return False
        return True
    
    class Segment():
        """A cubic Bezier segment described by its four control points.

        The points are kept as given on ``start``, ``ctrl1``, ``ctrl2`` and
        ``end``; ``length`` holds the approximate arc length computed once at
        construction time.
        NOTE(review): the points are presumably mathutils Vectors (they must
        support +, -, scalar * and ``.length``) -- confirm against callers.
        """

        #pts[0] - start, pts[1] - ctrl1, pts[2] - ctrl2, pts[3] - end
        @staticmethod
        def pointAtT(pts, t):
            """Evaluate the cubic with control points ``pts`` at parameter ``t``."""
            return pts[0] + t * (3 * (pts[1] - pts[0]) +
                t* (3 * (pts[0] + pts[2]) - 6 * pts[1] +
                    t * (-pts[0] + 3 * (pts[1] - pts[2]) + pts[3])))

        @staticmethod
        def getSegLenRecurs(pts, start, end, t1 = 0, t2 = 1, error = DEF_ERR_MARGIN):
            """Approximate the arc length of the curve between t1 and t2.

            ``start`` and ``end`` are the curve points at ``t1`` and ``t2``.
            The interval is bisected recursively for as long as going through
            the midpoint is longer than the straight chord by more than
            ``error``.
            """
            t1_5 = (t1 + t2)/2
            mid = Segment.pointAtT(pts, t1_5)
            l = (end - start).length
            l2 = (mid - start).length + (end - mid).length
            if (l2 - l > error):
                return (Segment.getSegLenRecurs(pts, start, mid, t1, t1_5, error) +
                        Segment.getSegLenRecurs(pts, mid, end, t1_5, t2, error))
            return l2

        def __init__(self, start, ctrl1, ctrl2, end):
            self.start = start
            self.ctrl1 = ctrl1
            self.ctrl2 = ctrl2
            self.end = end
            pts = [start, ctrl1, ctrl2, end]
            self.length = Segment.getSegLenRecurs(pts, start, end)

        #see https://stackoverflow.com/questions/878862/drawing-part-of-a-b%c3%a9zier-curve-by-reusing-a-basic-b%c3%a9zier-curve-function/879213#879213
        def partialSeg(self, t0, t1):
            """Return a new Segment covering this segment between t0 and t1.

            The two parameters may be given in either order.
            """
            pts = [self.start, self.ctrl1, self.ctrl2, self.end]

            if(t0 > t1):
                t0, t1 = t1, t0

            #Let's make at least the line segments of predictable length :)
            if(pts[0] == pts[1] and pts[2] == pts[3]):
                pt0 = Vector([(1 - t0) * pts[0][i] + t0 * pts[2][i] for i in range(0, 3)])
                pt1 = Vector([(1 - t1) * pts[0][i] + t1 * pts[2][i] for i in range(0, 3)])
                return Segment(pt0, pt0, pt1, pt1)

            u0 = 1.0 - t0
            u1 = 1.0 - t1

            qa = [pts[0][i]*u0*u0 + pts[1][i]*2*t0*u0 + pts[2][i]*t0*t0 for i in range(0, 3)]
            qb = [pts[0][i]*u1*u1 + pts[1][i]*2*t1*u1 + pts[2][i]*t1*t1 for i in range(0, 3)]
            qc = [pts[1][i]*u0*u0 + pts[2][i]*2*t0*u0 + pts[3][i]*t0*t0 for i in range(0, 3)]
            qd = [pts[1][i]*u1*u1 + pts[2][i]*2*t1*u1 + pts[3][i]*t1*t1 for i in range(0, 3)]

            pta = Vector([qa[i]*u0 + qc[i]*t0 for i in range(0, 3)])
            ptb = Vector([qa[i]*u1 + qc[i]*t1 for i in range(0, 3)])
            ptc = Vector([qb[i]*u0 + qd[i]*t0 for i in range(0, 3)])
            ptd = Vector([qb[i]*u1 + qd[i]*t1 for i in range(0, 3)])

            return Segment(pta, ptb, ptc, ptd)

        #see https://stackoverflow.com/questions/24809978/calculating-the-bounding-box-of-cubic-bezier-curve
        #(3 D - 9 C + 9 B - 3 A) t^2 + (6 A - 12 B + 6 C) t + 3 (B - A)
        #pts[0] - start, pts[1] - ctrl1, pts[2] - ctrl2, pts[3] - end
        #TODO: Return Vectors to make world space calculations consistent
        def bbox(self, mw = None):
            """Return the bounding box as [[minX, minY, minZ], [maxX, maxY, maxZ]].

            If ``mw`` is given, the four control points are transformed with
            ``mw @ point`` first. Per axis, the box starts from the two end
            points and is extended by the curve value at every root of the
            derivative that lies strictly inside (0, 1).
            """
            def evalBez(AA, BB, CC, DD, t):
                return AA * (1 - t) * (1 - t) * (1 - t) + \
                        3 * BB * t * (1 - t) * (1 - t) + \
                            3 * CC * t * t * (1 - t) + \
                                DD * t * t * t

            A = self.start
            B = self.ctrl1
            C = self.ctrl2
            D = self.end

            if(mw is not None):
                A = mw @ A
                B = mw @ B
                C = mw @ C
                D = mw @ D

            MINXYZ = [min([A[i], D[i]]) for i in range(0, 3)]
            MAXXYZ = [max([A[i], D[i]]) for i in range(0, 3)]
            leftBotBack_rgtTopFront = [MINXYZ, MAXXYZ]

            for i in range(0, 3):
                #Coefficients of the derivative a*t^2 + b*t + c for this axis
                a = 3 * D[i] - 9 * C[i] + 9 * B[i] - 3 * A[i]
                b = 6 * A[i] - 12 * B[i] + 6 * C[i]
                c = 3 * (B[i] - A[i])

                solns = []
                if(a == 0):
                    #Linear derivative: b*t + c = 0  =>  t = -c / b.
                    #(When b is 0 as well the derivative is constant, so
                    #there is no interior extremum to add.)
                    if(b != 0):
                        solns.append(-c / b)
                else:
                    rootFact = b * b - 4 * a * c
                    if(rootFact >= 0):
                        #Two solutions with + and - sqrt
                        root = sqrt(rootFact)
                        solns.append((-b + root) / (2 * a))
                        solns.append((-b - root) / (2 * a))

                for t in solns:
                    #The end points (t = 0 and t = 1) are already in the box
                    if(0 < t < 1):
                        co = evalBez(A[i], B[i], C[i], D[i], t)
                        if(co < leftBotBack_rgtTopFront[0][i]):
                            leftBotBack_rgtTopFront[0][i] = co
                        if(co > leftBotBack_rgtTopFront[1][i]):
                            leftBotBack_rgtTopFront[1][i] = co

            return leftBotBack_rgtTopFront
    
    
    class Part():
        """One spline of a Path, held as an ordered list of segments.

        Attributes:
            parent: the owning Path (only used to reach ``parent.curve`` for
                world space bounding boxes).
            segs: the list of segments, stored as given.
            isClosed: the closed flag passed in (use_cyclic_u of the spline).
            toClose: whether the rebuilt spline should be closed; starts out
                equal to ``isClosed``.
            length: sum of the segment lengths.
        """

        def __init__(self, parent, segs, isClosed):
            self.parent = parent
            self.segs = segs

            #use_cyclic_u
            self.isClosed = isClosed

            #Indicates if this should be closed based on its counterparts in other paths
            self.toClose = isClosed

            self.length = sum(seg.length for seg in self.segs)

            #Bounding box caches, filled lazily by getBBox
            self.bbox = None
            self.bboxWorldSpace = None

        def getSeg(self, idx):
            return self.segs[idx]

        def getSegs(self):
            return self.segs

        def getSegsCopy(self, start, end):
            """Return segs[start:end]; None means 'from the first' / 'to the last'."""
            if(start is None):
                start = 0
            if(end is None):
                end = len(self.segs)
            return self.segs[start:end]

        def getBBox(self, worldSpace):
            """Return [[minX, minY, minZ], [maxX, maxY, maxZ]] over all segments.

            With ``worldSpace`` the segments are transformed by the matrix_world
            of ``parent.curve``. Local and world space results are cached
            separately. For a part without segments every entry is None.
            """
            #Avoid frequent calculations, as this will be called in compare method
            cached = self.bboxWorldSpace if worldSpace else self.bbox
            if(cached is not None):
                return cached

            leftBotBack_rgtTopFront = [[None]*3,[None]*3]

            for seg in self.segs:

                if(worldSpace):
                    bb = seg.bbox(self.parent.curve.matrix_world)
                else:
                    bb = seg.bbox()

                for i in range(0, 3):
                    if (leftBotBack_rgtTopFront[0][i] is None or
                        bb[0][i] < leftBotBack_rgtTopFront[0][i]):
                        leftBotBack_rgtTopFront[0][i] = bb[0][i]

                    if (leftBotBack_rgtTopFront[1][i] is None or
                        bb[1][i] > leftBotBack_rgtTopFront[1][i]):
                        leftBotBack_rgtTopFront[1][i] = bb[1][i]

            if(worldSpace):
                self.bboxWorldSpace = leftBotBack_rgtTopFront
            else:
                self.bbox = leftBotBack_rgtTopFront

            return leftBotBack_rgtTopFront

        #private
        def getBBDiff(self, axisIdx, worldSpace):
            """Return the extent of the bounding box along the given axis."""
            bbox = self.getBBox(worldSpace)
            return abs(bbox[1][axisIdx] - bbox[0][axisIdx])

        def getBBWidth(self, worldSpace):
            return self.getBBDiff(0, worldSpace)

        def getBBHeight(self, worldSpace):
            return self.getBBDiff(1, worldSpace)

        def getBBDepth(self, worldSpace):
            return self.getBBDiff(2, worldSpace)

        def bboxSurfaceArea(self, worldSpace):
            """Return the surface area of the bounding box."""
            leftBotBack_rgtTopFront = self.getBBox(worldSpace)
            w = abs( leftBotBack_rgtTopFront[1][0] - leftBotBack_rgtTopFront[0][0] )
            l = abs( leftBotBack_rgtTopFront[1][1] - leftBotBack_rgtTopFront[0][1] )
            d = abs( leftBotBack_rgtTopFront[1][2] - leftBotBack_rgtTopFront[0][2] )

            return 2 * (w * l + w * d + l * d)

        def getSegCnt(self):
            return len(self.segs)

        def getBezierPtsInfo(self):
            """Return one [co, handleLeft, handleRight] entry per bezier point.

            Every segment contributes its start point. If the part is not to
            be closed, the end of the last segment is appended as a final
            point. A part without segments yields an empty list.
            """
            segs = self.getSegs()

            #Nothing to derive points from (previously this raised)
            if(len(segs) == 0):
                return []

            prevSeg = None
            bezierPtsInfo = []

            for j, seg in enumerate(segs):
                if(j > 0):
                    handleLeft = prevSeg.ctrl2
                elif(self.toClose):
                    #The first point of a closed spline is entered via the last segment
                    handleLeft = self.getSeg(-1).ctrl2
                else:
                    handleLeft = seg.start

                bezierPtsInfo.append([seg.start, handleLeft, seg.ctrl1])
                prevSeg = seg

            #A closed spline returns to its first point, so only an open one
            #gets an explicit final point
            if(self.toClose != True):
                bezierPtsInfo.append([prevSeg.end, prevSeg.ctrl2, prevSeg.end])

            return bezierPtsInfo

        def __repr__(self):
            return str(self.length)
    
    
    class Path:
        """A curve object together with its splines broken up into Part objects."""

        def __init__(self, curve, objData = None, name = None):
            """Build the parts from objData.splines.

            objData defaults to curve.data and name to curve.name.
            """
            self.curve = curve
            self.name = curve.name if name is None else name

            splines = (curve.data if objData is None else objData).splines

            self.parts = []
            for spline in splines:
                self.parts.append(Part(self, getSplineSegs(spline), spline.use_cyclic_u))

        def getPartCnt(self):
            return len(self.parts)

        def getPartView(self):
            """Return a single Part holding the segments of all parts in order."""
            allSegs = []
            for part in self.parts:
                allSegs.extend(part.getSegs())
            return Part(self, allSegs, None)

        def getPartBoundaryIdxs(self):
            """Return the set of cumulative segment counts at the end of each part."""
            boundaries = set()
            runningCnt = 0

            for part in self.parts:
                runningCnt += part.getSegCnt()
                boundaries.add(runningCnt)

            return boundaries

        def updatePartsList(self, segCntsPerPart, byPart):
            """Re-split the segments of this path at the given cumulative counts.

            The parts list is rebuilt in place. A new part only inherits the
            closed flag of the old part it is cut from when it starts and ends
            where that old part does; otherwise it is marked as not closed.
            """
            allSegs = [seg for part in self.parts for seg in part.getSegs()]
            oldParts = self.parts[:]
            currPart = oldParts[0]
            oldPartIdx = 0
            self.parts.clear()

            sliceStart = 0
            for sliceEnd in segCntsPerPart:
                currSegs = currPart.getSegs()

                coversOldPart = \
                    vectCmpWithMargin(allSegs[sliceStart].start, currSegs[0].start) and \
                    vectCmpWithMargin(allSegs[sliceEnd-1].end, currSegs[-1].end)

                isClosed = currPart.isClosed if coversOldPart else False

                self.parts.append(Part(self, allSegs[sliceStart:sliceEnd], isClosed))

                #Reached the last segment of the old part; move on to the next one
                if(allSegs[sliceEnd-1] == currPart.getSegs()[-1]):
                    oldPartIdx += 1
                    if(oldPartIdx < len(oldParts)):
                        currPart = oldParts[oldPartIdx]

                sliceStart = sliceEnd

        def getBezierPtsBySpline(self):
            """Return the bezier point info of every part, one list per part."""
            return [part.getBezierPtsInfo() for part in self.parts]

        def getNewCurveData(self):
            """Return a copy of the curve data with its splines rebuilt from the parts."""
            newCurveData = self.curve.data.copy()
            newCurveData.splines.clear()

            splinesData = self.getBezierPtsBySpline()

            for part, ptsInfo in zip(self.parts, splinesData):

                spline = newCurveData.splines.new('BEZIER')
                #The points are added on top of the one(s) the new spline already has
                spline.bezier_points.add(len(ptsInfo)-1)
                spline.use_cyclic_u = part.toClose

                for j, bezPt in enumerate(spline.bezier_points):
                    info = ptsInfo[j]
                    bezPt.co = info[0]
                    bezPt.handle_left = info[1]
                    bezPt.handle_right = info[2]
                    bezPt.handle_right_type = 'FREE'

            return newCurveData

        def updateCurve(self):
            """Replace the data of the curve object with the rebuilt curve data.

            Existing shape keys are removed first (last one first) and the
            previous curve data is removed afterwards.
            """
            oldData = self.curve.data

            #Remove existing shape keys first
            if(oldData.shape_keys is not None):
                for sk in reversed(oldData.shape_keys.key_blocks):
                    self.curve.shape_key_remove(sk)

            self.curve.data = self.getNewCurveData()
            bpy.data.curves.remove(oldData)
    
    def main(removeOriginal, space, matchParts, matchCriteria, alignBy, alignValues):
        """Assign the selected Bezier curves as shape keys of the active one.

        Returns None without doing anything when the active object is not a
        Bezier curve or no other Bezier curve is selected; returns an empty
        dict after the shape keys were assigned.
        """
        targetObj = bpy.context.active_object
        if(targetObj is None or not isBezier(targetObj)):
            return

        target = Path(targetObj)

        shapekeys = []
        for obj in bpy.context.selected_objects:
            if(isBezier(obj) and obj != targetObj):
                shapekeys.append(Path(obj))

        if(not shapekeys):
            return

        #Shape keys already on the target are kept, ahead of the new ones
        shapekeys = getExistingShapeKeyPaths(target) + shapekeys
        userSel = [target] + shapekeys

        for path in userSel:
            alignPath(path, matchParts, matchCriteria, alignBy, alignValues)

        addMissingSegs(userSel, byPart = (matchParts != "-None-"))

        #Split every path at the union of the part boundaries of all paths
        bIdxs = set()
        for path in userSel:
            bIdxs |= path.getPartBoundaryIdxs()

        boundaries = sorted(bIdxs)
        for path in userSel:
            path.updatePartsList(boundaries, byPart = False)

        #All will have the same part count by now
        #All paths will have the same no of splines with the same no of bezier points
        for j in range(len(userSel[0].parts)):
            closeIt = all(path.parts[j].isClosed for path in userSel)
            for path in userSel:
                path.parts[j].toClose = closeIt

        target.updateCurve()

        target.curve.shape_key_add(name = 'Basis')

        addShapeKeys(target.curve, shapekeys, space)

        if(removeOriginal):
            for path in userSel:
                if(path.curve != target.curve):
                    safeRemoveCurveObj(path.curve)

        return {}
    
    def getSplineSegs(spline):
        """Return one Segment per pair of consecutive bezier points of spline.

        For a cyclic spline an extra segment from the last point back to the
        first one is appended.
        """
        points = spline.bezier_points
        segs = []

        for idx in range(1, len(points)):
            prevPt = points[idx-1]
            currPt = points[idx]
            segs.append(Segment(prevPt.co, prevPt.handle_right, \
                currPt.handle_left, currPt.co))

        if(spline.use_cyclic_u):
            lastPt = points[-1]
            firstPt = points[0]
            segs.append(Segment(lastPt.co, lastPt.handle_right, \
                firstPt.handle_left, firstPt.co))

        return segs
    
    def subdivideSeg(origSeg, noSegs):
        """Split origSeg into noSegs pieces at equally spaced parameter values.

        For noSegs below 2 the original segment is returned as the only
        element of the list.
        """
        if(noSegs < 2):
            return [origSeg]

        #Upper bounds of the pieces: 1/n, 2/n, ... (n-1)/n and finally 1
        upperBounds = [float(k) / noSegs for k in range(1, noSegs)] + [1]

        pieces = []
        lower = 0
        for upper in upperBounds:
            pieces.append(origSeg.partialSeg(lower, upper))
            lower = upper

        return pieces
    
    
    def getSubdivCntPerSeg(part, toAddCnt):
        """Decide how many extra segments each segment of part is split into.

        Returns a list with one count per segment of ``part`` (in segment
        order). Only segments at least as long as the target average length
        are candidates; they receive counts roughly proportional to their
        length, longest first, and whatever is still missing afterwards is
        handed out one by one in the same order. For a part with segments and
        a positive ``toAddCnt`` the counts add up to ``toAddCnt``.

        A part without segments, or a ``toAddCnt`` that is not positive,
        yields all zeros. Degenerate parts whose candidate segments all have
        zero length no longer divide by zero; they are served by the
        one-by-one pass only.
        """
        segs = part.getSegs()
        segCnt = len(segs)
        cnts = [0] * segCnt

        #Nothing to hand out, or nothing to hand it out to
        if(segCnt == 0 or toAddCnt <= 0):
            return cnts

        #Average length if the new segments were spread over the whole part
        avgLen = part.length / (segCnt + toAddCnt)

        #(index, length) of the segments long enough to be worth dividing
        segsToDivide = [(idx, seg.length) for idx, seg in enumerate(segs) \
            if seg.length >= avgLen]

        #Can only happen if part.length disagrees with the segment lengths;
        #fall back to all segments so that the loop below always terminates
        if(len(segsToDivide) == 0):
            segsToDivide = [(idx, seg.length) for idx, seg in enumerate(segs)]

        #Average again, now over the candidates only
        avgLen = sum(segLen for idx, segLen in segsToDivide) / \
            (len(segsToDivide) + toAddCnt)

        segsToDivide.sort(key = lambda entry: entry[1], reverse = True)

        addedCnt = 0

        #Proportional pass; skipped when the candidates have no length at all
        if(avgLen > 0):
            for idx, segLen in segsToDivide:
                divideCnt = int(round(segLen/avgLen)) - 1
                if(divideCnt == 0):
                    break

                if((addedCnt + divideCnt) >= toAddCnt):
                    cnts[idx] = toAddCnt - addedCnt
                    addedCnt = toAddCnt
                    break

                cnts[idx] = divideCnt
                addedCnt += divideCnt

        #Hand out whatever is still missing, one at a time, longest first
        while(toAddCnt > addedCnt):
            for idx, segLen in segsToDivide:
                cnts[idx] += 1
                addedCnt += 1
                if(toAddCnt == addedCnt):
                    break

        return cnts
    
    #Just distribute equally; this is likely a rare condition. So why complicate?
    def distributeCnt(maxSegCntsByPart, startIdx, extraCnt):
        """Spread extraCnt over the entries of maxSegCntsByPart from startIdx on.

        The list is modified in place: every entry gets the same share and
        the remainder goes, one each, to the first entries of the range.
        """
        slotCnt = len(maxSegCntsByPart) - startIdx
        share, leftover = divmod(extraCnt, slotCnt)

        for offset in range(slotCnt):
            bonus = 1 if offset < leftover else 0
            maxSegCntsByPart[startIdx + offset] += share + bonus
    
    #Make all the paths to have the maximum number of segments in the set
    #TODO: Refactor
    def addMissingSegs(selPaths, byPart):
        """Subdivide segments so that all paths end up with matching segment counts.

        With byPart False, every path is brought up to the largest total
        segment count found among the paths. Otherwise the counts are matched
        part by part; a path with fewer parts than the others takes, in its
        last part, the segments of all the remaining parts.

        The parts of the given paths are replaced in place by new Part
        objects holding the subdivided segments; nothing is returned.
        """
        #Target segment count of each part position (byPart only)
        maxSegCntsByPart = []
        #Largest total segment count of any path (not byPart only)
        maxSegCnt = 0

        #Per path (in sortedPaths order): the segment count of each of its parts,
        #with the last entry adjusted for paths that have fewer parts
        resSegCnt = []
        #Paths with the most parts come first
        sortedPaths = sorted(selPaths, key = lambda c: -len(c.parts))

        #First pass: work out the target counts
        for i, path in enumerate(sortedPaths):
            if(byPart == False):
                segCnt = path.getPartView().getSegCnt()
                if(segCnt > maxSegCnt):
                    maxSegCnt = segCnt
            else:
                resSegCnt.append([])
                for j, part in enumerate(path.parts):
                    partSegCnt = part.getSegCnt()
                    resSegCnt[i].append(partSegCnt)

                    #First path
                    if(j == len(maxSegCntsByPart)):
                        maxSegCntsByPart.append(partSegCnt)

                    #last part of this path, but other paths in set have more parts
                    elif((j == len(path.parts) - 1) and
                        len(maxSegCntsByPart) > len(path.parts)):

                        remainingSegs = sum(maxSegCntsByPart[j:])
                        if(partSegCnt <= remainingSegs):
                            resSegCnt[i][j] = remainingSegs
                        else:
                            #This part has more segs than the sum of the remaining part segs
                            #So distribute the extra count
                            distributeCnt(maxSegCntsByPart, j, (partSegCnt - remainingSegs))

                            #Also, adjust the seg count of the last part of the previous
                            #segments that had fewer than max number of parts
                            for k in range(0, i):
                                if(len(sortedPaths[k].parts) < len(maxSegCntsByPart)):
                                    totalSegs = sum(maxSegCntsByPart)
                                    existingSegs = sum(maxSegCntsByPart[:len(sortedPaths[k].parts)-1])
                                    resSegCnt[k][-1] = totalSegs - existingSegs

                    elif(partSegCnt > maxSegCntsByPart[j]):
                        maxSegCntsByPart[j] = partSegCnt
        #Second pass: subdivide the segments to reach the target counts
        for i, path in enumerate(sortedPaths):

            if(byPart == False):
                partView = path.getPartView()
                segCnt = partView.getSegCnt()
                diff = maxSegCnt - segCnt

                if(diff > 0):
                    #The counts are indexed over the segments of all parts combined
                    cnts = getSubdivCntPerSeg(partView, diff)
                    cumulSegIdx = 0
                    for j in range(0, len(path.parts)):
                        part = path.parts[j]
                        newSegs = []
                        for k, seg in enumerate(part.getSegs()):
                            numSubdivs = cnts[cumulSegIdx] + 1
                            newSegs += subdivideSeg(seg, numSubdivs)
                            cumulSegIdx += 1

                        path.parts[j] = Part(path, newSegs, part.isClosed)
            else:
                for j in range(0, len(path.parts)):
                    part = path.parts[j]
                    newSegs = []

                    partSegCnt = part.getSegCnt()

                    #TODO: Adding everything in the last part?
                    if(j == (len(path.parts)-1) and
                        len(maxSegCntsByPart) > len(path.parts)):
                        diff = resSegCnt[i][j] - partSegCnt
                    else:
                        diff = maxSegCntsByPart[j] - partSegCnt

                    if(diff > 0):
                        cnts = getSubdivCntPerSeg(part, diff)

                        for k, seg in enumerate(part.getSegs()):
                            #Same segment as the loop variable already holds
                            seg = part.getSeg(k)
                            subdivCnt = cnts[k] + 1 #1 for the existing one
                            newSegs += subdivideSeg(seg, subdivCnt)

                        #isClosed won't be used, but let's update anyway
                        path.parts[j] = Part(path, newSegs, part.isClosed)
    
    #TODO: Simplify (Not very readable)
    def alignPath(path, matchParts, matchCriteria, alignBy, alignValues):
        """Reorder the parts of path and pick the starting point of closed parts.

        With matchParts == 'custom' the parts are sorted by matchCriteria,
        each criterion only deciding when the earlier ones tie within the
        default margin. Count/area/size criteria put the larger part first;
        'minX' ... 'maxZ' sort ascending by that world space bounding box
        coordinate.

        With alignBy == 'vertCo' every closed part is rotated so that it
        starts at the segment whose world space start point is preferred by
        alignValues (entries such as 'minX' or 'maxZ', evaluated in order).

        path.parts is updated in place; nothing is returned.
        """
        parts = path.parts[:]

        if(matchParts == 'custom'):
            #Negated so that the larger value sorts first
            fnMap = {'vCnt' : lambda part: -1 * part.getSegCnt(), \
                     'bbArea': lambda part: -1 * part.bboxSurfaceArea(worldSpace = True), \
                     'bbHeight' : lambda part: -1 * part.getBBHeight(worldSpace = True), \
                     'bbWidth' : lambda part: -1 * part.getBBWidth(worldSpace = True), \
                     'bbDepth' : lambda part: -1 * part.getBBDepth(worldSpace = True)
                    }

            def bboxCoordFn(cornerIdx, axisIdx):
                #A factory, so that every key function keeps its own indices
                #(this replaces a lambda that was assembled with eval)
                return lambda part: part.getBBox(worldSpace = True)[cornerIdx][axisIdx]

            matchPartCmpFns = []
            for criterion in matchCriteria:
                fn = fnMap.get(criterion)
                if(fn is None):
                    cornerIdx = 1 if criterion[:3] == 'max' else 0 #0 if min; 1 if max
                    axisIdx = ord(criterion[3:]) - ord('X')
                    fn = bboxCoordFn(cornerIdx, axisIdx)

                matchPartCmpFns.append(fn)

            def comparer(left, right):
                for fn in matchPartCmpFns:
                    a = fn(left)
                    b = fn(right)

                    if(floatCmpWithMargin(a, b)):
                        continue

                    return (a > b) - ( a < b) #No cmp in python3

                return 0

            parts = sorted(parts, key = cmp_to_key(comparer))

        alignCmpFn = None
        if(alignBy == 'vertCo'):
            def evalCmp(criteria, pt1, pt2):
                #The first criterion on which the points differ decides;
                #points equal on all of them count as a match
                for minmax, axisIdx in criteria:
                    val1 = pt1[axisIdx]
                    val2 = pt2[axisIdx]

                    if(floatCmpWithMargin(val1, val2)):
                        continue

                    return val1 < val2 if minmax == 'min' else val1 > val2

                return True

            alignCri = [[a[:3], ord(a[3:]) - ord('X')] for a in alignValues]
            alignCmpFn = lambda pt1, pt2, curve: (evalCmp(alignCri, \
                curve.matrix_world @ pt1, curve.matrix_world @ pt2))

        for i, part in enumerate(parts):
            #Only truly closed parts
            if(alignCmpFn is not None and part.isClosed):
                startPt = None
                startIdx = None

                for j in range(0, part.getSegCnt()):
                    seg = part.getSeg(j)
                    if(j == 0 or alignCmpFn(seg.start, startPt, path.curve)):
                        startPt = seg.start
                        startIdx = j

                #Rotate the segments so that the chosen one comes first
                path.parts[i]= Part(path, part.getSegsCopy(startIdx, None) + \
                    part.getSegsCopy(None, startIdx), part.isClosed)
            else:
                path.parts[i] = part
    
    #TODO: Other shape key attributes like interpolation...?
    def getExistingShapeKeyPaths(path):
        """Return a Path for every existing shape key of path.curve except the first.

        Every Path is built from a copy of the curve data whose bezier points
        are overwritten with the shape key's points, and is named after the key.
        """
        obj = path.curve
        shapeKeys = obj.data.shape_keys

        if(shapeKeys is None):
            return []

        paths = []
        for key in shapeKeys.key_blocks[1:]:#Skip basis
            datacopy = obj.data.copy()

            #Key data is indexed over the points of all splines in sequence
            ptIdx = 0
            for spline in datacopy.splines:
                for pt in spline.bezier_points:
                    keyPt = key.data[ptIdx]
                    pt.co = keyPt.co
                    pt.handle_left = keyPt.handle_left
                    pt.handle_right = keyPt.handle_right
                    ptIdx += 1

            paths.append(Path(obj, datacopy, key.name))

        return paths
    
    def addShapeKeys(curve, paths, space):
        """Add one shape key per path to curve and fill in its bezier points.

        The key is named after the path. With space == 'worldspace' every
        coordinate is first taken to world space with the matrix_world of the
        path's curve and then into the local space of ``curve``; otherwise
        the coordinates are used as they are.
        """
        toTargetLocal = None
        if(space == 'worldspace' and len(paths) > 0):
            #Does not depend on the path or the point, so invert just once
            toTargetLocal = curve.matrix_world.inverted()

        for path in paths:
            key = curve.shape_key_add(name = path.name)
            pts = [pt for pset in path.getBezierPtsBySpline() for pt in pset]

            if(toTargetLocal is not None):
                pathMw = path.curve.matrix_world
                pts = [[toTargetLocal @ (pathMw @ p) for p in pt] for pt in pts]

            for i, pt in enumerate(pts):
                key.data[i].co = pt[0]
                key.data[i].handle_left = pt[1]
                key.data[i].handle_right = pt[2]
    
    #TODO: Remove try
    def safeRemoveCurveObj(obj):
        """Best-effort removal of obj, and of its data if nothing else uses it.

        The object is unlinked from its collections and from the scene
        collection, its curve or mesh data is removed when obj is the only
        user, and finally the object itself is removed. Errors raised on the
        way are ignored on purpose.
        """
        try:
            for c in obj.users_collection:
                c.objects.unlink(obj)

            sceneObjs = bpy.context.scene.collection.objects
            if(obj.name in sceneObjs):
                sceneObjs.unlink(obj)

            if(obj.data.users == 1):
                if(obj.type == 'CURVE'):
                    bpy.data.curves.remove(obj.data) #This also removes object?
                elif(obj.type == 'MESH'):
                    bpy.data.meshes.remove(obj.data)

            bpy.data.objects.remove(obj)
        except Exception:
            #Still best effort, but no longer a bare except: KeyboardInterrupt
            #and SystemExit are not swallowed any more
            pass
    
    
    def markVertHandler(self, context):
        """Invoke the ``wm.mark_vertex`` operator when ``self.markVertex`` is set.

        Does nothing when the flag is false. ``context`` is accepted but not
        used.
        """
        if(not self.markVertex):
            return

        bpy.ops.wm.mark_vertex()
    
    
    #################### UI and Registration ####################
    
    
        # Operator metadata for the shape key assignment operator.
        # NOTE(review): presumably invoked as bpy.ops.object.assign_shape_keys()
        # (bl_idname is "<category>.<name>") - confirm against the class header.
        bl_idname = "object.assign_shape_keys"
        bl_label = "Assign Shape Keys"
        # 'UNDO' is listed alongside 'REGISTER' in the operator options
        bl_options = {'REGISTER', 'UNDO'}
    
        def execute(self, context):
            """Run ``main`` with the options stored on the window manager.

            All settings are read from
            ``context.window_manager.AssignShapeKeyParams``. The three match
            criteria and the three align values are each passed to ``main`` as
            a list. The value returned by ``main`` is not used.

            Returns:
                ``{'FINISHED'}``.
            """
            opts = context.window_manager.AssignShapeKeyParams

            removeOriginal, space = opts.removeOriginal, opts.space
            matchParts = opts.matchParts
            matchCriteria = [opts.matchCri1, opts.matchCri2, opts.matchCri3]
            alignBy = opts.alignList
            alignValues = [opts.alignVal1, opts.alignVal2, opts.alignVal3]

            main(removeOriginal, space, matchParts, matchCriteria,
                alignBy, alignValues)

            return {'FINISHED'}
    
    
    class MarkerController:
        """Tracks and draws the start-vertex marker of closed Bezier splines.

        ``smMap`` maps a curve object's name to a dict of
        ``{spline index: [world-space marker position, point index]}``. Only
        splines with ``use_cyclic_u`` set are tracked, and curves without any
        such spline have no entry.

        Creating an instance registers a 3D-view draw handler (stored in the
        class attribute ``drawHandlerRef``) that draws one point per tracked
        spline; ``removeMarkers`` unregisters it again.
        """
        # Handle returned by draw_handler_add, None while nothing is registered
        drawHandlerRef = None
        # Point size passed to glPointSize when drawing the markers
        defPointSize = 6
        # Marker colour, passed as the per-vertex "color" attribute
        ptColor = (0, .8, .8, 1)

        def createSMMap(self, context):
            """Build the initial marker map from the selected Bezier curves.

            Every closed spline starts with its marker on point 0, and all its
            control points are deselected. Objects for which ``isBezier``
            returns false are skipped.

            Returns:
                dict as described in the class docstring.
            """
            smMap = {}
            for curve in context.selected_objects:
                if(not isBezier(curve)):
                    continue

                mw = curve.matrix_world
                spMap = {}
                for splineIdx, spline in enumerate(curve.data.splines):
                    if(not spline.use_cyclic_u):
                        continue

                    pts = spline.bezier_points

                    #initialize to the curr start vert co and idx
                    spMap[splineIdx] = [mw @ pts[0].co, 0]

                    for pt in pts:
                        pt.select_control_point = False

                # Curves without any closed spline get no entry at all
                if(spMap):
                    smMap[curve.name] = spMap

            return smMap

        def createBatch(self, context):
            """Rebuild the point batch from ``smMap`` and request a redraw."""
            positions = [s[0] for cn in self.smMap.values() for s in cn.values()]
            colors = [MarkerController.ptColor] * len(positions)

            self.batch = batch_for_shader(self.shader, \
                "POINTS", {"pos": positions, "color": colors})

            if context.area:
                context.area.tag_redraw()

        def drawHandler(self):
            """Draw callback: render the current batch as points."""
            bgl.glPointSize(MarkerController.defPointSize)
            self.batch.draw(self.shader)

        def removeMarkers(self, context):
            """Unregister the draw handler (if any) and deselect all points."""
            if(MarkerController.drawHandlerRef is not None):
                bpy.types.SpaceView3D.draw_handler_remove(MarkerController.drawHandlerRef, \
                    "WINDOW")

                if(context.area and hasattr(context.space_data, 'region_3d')):
                    context.area.tag_redraw()

                MarkerController.drawHandlerRef = None

            self.deselectAll()

        def __init__(self, context):
            """Create the marker map, the shader, the draw handler and the batch."""
            self.smMap = self.createSMMap(context)
            self.shader = gpu.shader.from_builtin('3D_FLAT_COLOR')
            self.shader.bind()

            MarkerController.drawHandlerRef = \
                bpy.types.SpaceView3D.draw_handler_add(self.drawHandler, \
                    (), "WINDOW", "POST_VIEW")

            self.createBatch(context)

        def saveStartVerts(self):
            """Rotate each tracked spline so the marked point becomes point 0.

            Splines whose marker is still on index 0 are left untouched. For
            the others, point ``i`` receives the position, handles and handle
            types of the former point ``(idx + i) % cnt``.
            """
            for curveName, spMap in self.smMap.items():
                splines = bpy.data.objects[curveName].data.splines

                for splineIdx, markerInfo in spMap.items():
                    idx = markerInfo[1]
                    if(idx == 0):
                        continue

                    pts = splines[splineIdx].bezier_points
                    cnt = len(pts)

                    # Snapshot first: the points are overwritten in place below
                    ptCopy = [[p.co.copy(), p.handle_right.copy(), \
                        p.handle_left.copy(), p.handle_right_type, \
                            p.handle_left_type] for p in pts]

                    for i, pt in enumerate(pts):
                        p = ptCopy[(idx + i) % cnt]

                        #Must set the types first
                        pt.handle_right_type = p[3]
                        pt.handle_left_type = p[4]
                        pt.co = p[0]
                        pt.handle_right = p[1]
                        pt.handle_left = p[2]

        def updateSMMap(self):
            """Move each marker to the first selected control point of its spline.

            When no control point of a spline is selected the previously stored
            index is kept; the world-space position is recomputed either way.
            """
            for curveName, spMap in self.smMap.items():
                curve = bpy.data.objects[curveName]
                mw = curve.matrix_world

                for splineIdx, markerInfo in spMap.items():
                    idx = markerInfo[1]
                    pts = curve.data.splines[splineIdx].bezier_points

                    # First selected point wins, otherwise keep the old index
                    selIdx = next((i for i, pt in enumerate(pts) \
                        if pt.select_control_point), idx)

                    # Only the value of an existing key is replaced here
                    spMap[splineIdx] = [mw @ pts[selIdx].co, selIdx]

        def deselectAll(self):
            """Deselect every Bezier control point of every tracked curve."""
            for curveName in self.smMap.keys():
                curve = bpy.data.objects[curveName]
                for spline in curve.data.splines:
                    for pt in spline.bezier_points:
                        pt.select_control_point = False

        # The three helpers below take no 'self'. They are declared static so
        # that they work both as MarkerController.f(...) and on an instance.
        @staticmethod
        def getSpaces3D(context):
            """Return the VIEW_3D spaces of all VIEW_3D areas in the window's screen."""
            areas3d  = [area for area in context.window.screen.areas \
                if area.type == 'VIEW_3D']

            return [s for a in areas3d for s in a.spaces if s.type == 'VIEW_3D']

        @staticmethod
        def hideHandles(context):
            """Turn off curve handle display in every 3D space.

            Returns:
                list with the previous ``show_curve_handles`` value of each
                space, in the order produced by ``getSpaces3D``.
            """
            states = []
            for s in MarkerController.getSpaces3D(context):
                states.append(s.overlay.show_curve_handles)
                s.overlay.show_curve_handles = False
            return states

        @staticmethod
        def resetShowHandleState(context, handleStates):
            """Restore ``show_curve_handles`` from a list made by ``hideHandles``.

            Spaces and saved states are paired by position. If the number of
            3D spaces differs from ``len(handleStates)``, only the common
            prefix is restored instead of raising ``IndexError``.
            """
            spaces = MarkerController.getSpaces3D(context)
            for s, state in zip(spaces, handleStates):
                s.overlay.show_curve_handles = state
    
    
    
        # Operator metadata for the start-vertex marking operator.
        bl_description = "Mark Vertex"
        # Matches the bpy.ops.wm.mark_vertex() call made by markVertHandler
        bl_idname = "wm.mark_vertex"
        bl_label = "Mark Start Vertex"
    
        def cleanup(self, context):
            """Tear down the marking session.

            Removes the event timer, removes the markers, restores the curve
            handle display saved in ``self.handleStates`` and clears the
            ``markVertex`` flag on the window manager parameters.
            """
            winMgr = context.window_manager
            winMgr.event_timer_remove(self._timer)

            self.markerState.removeMarkers(context)
            MarkerController.resetShowHandleState(context, self.handleStates)

            winMgr.AssignShapeKeyParams.markVertex = False
    
        def modal (self, context, event):
            """Handle one event of the start-vertex marking session.

            Returns:
                ``{'CANCELLED'}`` after cleanup when the mode is 'OBJECT', the
                event is ESC, or the ``markVertex`` flag has been cleared;
                ``{'FINISHED'}`` after saving the start vertices and cleanup
                on RET; ``{'RUNNING_MODAL'}`` for shift key events;
                ``{"PASS_THROUGH"}`` for everything else.
            """
            params = context.window_manager.AssignShapeKeyParams
            evType = event.type

            if(context.mode  == 'OBJECT' or evType == "ESC" or \
                not params.markVertex):
                self.cleanup(context)
                return {'CANCELLED'}

            if(evType == "RET"):
                self.markerState.saveStartVerts()
                self.cleanup(context)
                return {'FINISHED'}

            if(evType == 'TIMER'):
                # Timer tick: follow the current selection and redraw markers
                self.markerState.updateSMMap()
                self.markerState.createBatch(context)

            elif(evType in {'LEFT_CTRL', 'RIGHT_CTRL'}):
                self.ctrl = (event.value == 'PRESS')

            elif(evType in {'LEFT_SHIFT', 'RIGHT_SHIFT'}):
                self.shift = (event.value == 'PRESS')

                # NOTE(review): inside this branch evType is always a shift
                # key, so the test below is always true and only shift events
                # are consumed. Possibly meant to run for every event type -
                # confirm intent before changing; behaviour kept as is.
                passThroughTypes = {"MIDDLEMOUSE", "TAB", "LEFTMOUSE", \
                    "RIGHTMOUSE", 'WHEELDOWNMOUSE', 'WHEELUPMOUSE'}
                if(evType not in passThroughTypes and \
                    not evType.startswith("NUMPAD_")):
                    return {'RUNNING_MODAL'}

            return {"PASS_THROUGH"}
    
        def execute(self, context):
            """Start the modal start-vertex marking session.

            Adds an event timer for the current window, registers this operator
            as a modal handler, creates the ``MarkerController`` and hides the
            curve handles, remembering their previous state in
            ``self.handleStates``.

            Returns:
                ``{"RUNNING_MODAL"}``.
            """
            winMgr = context.window_manager

            #TODO: Why such small step?
            self._timer = winMgr.event_timer_add(time_step = 0.0001, \
                window = context.window)
            self.ctrl = self.shift = False

            winMgr.modal_handler_add(self)
            self.markerState = MarkerController(context)

            #Hide so that users don't accidentally select handles instead of points
            self.handleStates = MarkerController.hideHandles(context)

            return {"RUNNING_MODAL"}
    
    
    class AssignShapeKeyParams(bpy.types.PropertyGroup):
    
        removeOriginal : BoolProperty(name = "Remove Shape Key Objects", \
            description = "Remove shape key objects after assigning to target", \
                default = True)
    
        space : EnumProperty(name = "Space", \
            items = [('worldspace', 'World Space', 'worldspace'),
                     ('localspace', 'Local Space', 'localspace')], \
            description = 'Space that shape keys are evluated in')