dia2django.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. # -*- coding: utf-8 -*-
  2. """
  3. Author Igor Támara igor@tamarapatino.org
  4. Use this little program as you wish, if you
  5. include it in your work, let others know you
  6. are using it preserving this note, you have
  7. the right to make derivative works, Use it
  8. at your own risk.
  9. Tested to work on(etch testing 13-08-2007):
  10. Python 2.4.4 (#2, Jul 17 2007, 11:56:54)
  11. [GCC 4.1.3 20070629 (prerelease) (Debian 4.1.2-13)] on linux2
  12. """
  13. import codecs
  14. import gzip
  15. import re
  16. import sys
  17. from xml.dom.minidom import Node, parseString
  18. import six
  19. dependclasses = ["User", "Group", "Permission", "Message"]
  20. # Type dictionary translation types SQL -> Django
  21. tsd = {
  22. "text": "TextField",
  23. "date": "DateField",
  24. "varchar": "CharField",
  25. "int": "IntegerField",
  26. "float": "FloatField",
  27. "serial": "AutoField",
  28. "boolean": "BooleanField",
  29. "numeric": "FloatField",
  30. "timestamp": "DateTimeField",
  31. "bigint": "IntegerField",
  32. "datetime": "DateTimeField",
  33. "time": "TimeField",
  34. "bool": "BooleanField",
  35. }
  36. # convert varchar -> CharField
  37. v2c = re.compile(r'varchar\((\d+)\)')
  38. def find_index(fks, id_):
  39. """
  40. Look for the id on fks, fks is an array of arrays, each array has on [1]
  41. the id of the class in a dia diagram. When not present returns None, else
  42. it returns the position of the class with id on fks
  43. """
  44. for i, _ in fks.items():
  45. if fks[i][1] == id_:
  46. return i
  47. return None
  48. def addparentstofks(rels, fks):
  49. """
  50. Get a list of relations, between parents and sons and a dict of
  51. clases named in dia, and modifies the fks to add the parent as fk to get
  52. order on the output of classes and replaces the base class of the son, to
  53. put the class parent name.
  54. """
  55. for j in rels:
  56. son = find_index(fks, j[1])
  57. parent = find_index(fks, j[0])
  58. fks[son][2] = fks[son][2].replace("models.Model", parent)
  59. if parent not in fks[son][0]:
  60. fks[son][0].append(parent)
  61. def dia2django(archivo):
  62. models_txt = ''
  63. f = codecs.open(archivo, "rb")
  64. # dia files are gzipped
  65. data = gzip.GzipFile(fileobj=f).read()
  66. ppal = parseString(data)
  67. # diagram -> layer -> object -> UML - Class -> name, (attribs : composite -> name,type)
  68. datos = ppal.getElementsByTagName("dia:diagram")[0].getElementsByTagName("dia:layer")[0].getElementsByTagName("dia:object")
  69. clases = {}
  70. herit = []
  71. imports = six.u("")
  72. for i in datos:
  73. # Look for the classes
  74. if i.getAttribute("type") == "UML - Class":
  75. myid = i.getAttribute("id")
  76. for j in i.childNodes:
  77. if j.nodeType == Node.ELEMENT_NODE and j.hasAttributes():
  78. if j.getAttribute("name") == "name":
  79. actclas = j.getElementsByTagName("dia:string")[0].childNodes[0].data[1:-1]
  80. myname = "\nclass %s(models.Model) :\n" % actclas
  81. clases[actclas] = [[], myid, myname, 0]
  82. if j.getAttribute("name") == "attributes":
  83. for ll in j.getElementsByTagName("dia:composite"):
  84. if ll.getAttribute("type") == "umlattribute":
  85. # Look for the attribute name and type
  86. for k in ll.getElementsByTagName("dia:attribute"):
  87. if k.getAttribute("name") == "name":
  88. nc = k.getElementsByTagName("dia:string")[0].childNodes[0].data[1:-1]
  89. elif k.getAttribute("name") == "type":
  90. tc = k.getElementsByTagName("dia:string")[0].childNodes[0].data[1:-1]
  91. elif k.getAttribute("name") == "value":
  92. val = k.getElementsByTagName("dia:string")[0].childNodes[0].data[1:-1]
  93. if val == '##':
  94. val = ''
  95. elif k.getAttribute("name") == "visibility" and k.getElementsByTagName("dia:enum")[0].getAttribute("val") == "2":
  96. if tc.replace(" ", "").lower().startswith("manytomanyfield("):
  97. # If we find a class not in our model that is marked as being to another model
  98. newc = tc.replace(" ", "")[16:-1]
  99. if dependclasses.count(newc) == 0:
  100. dependclasses.append(newc)
  101. if tc.replace(" ", "").lower().startswith("foreignkey("):
  102. # If we find a class not in our model that is marked as being to another model
  103. newc = tc.replace(" ", "")[11:-1]
  104. if dependclasses.count(newc) == 0:
  105. dependclasses.append(newc)
  106. # Mapping SQL types to Django
  107. varch = v2c.search(tc)
  108. if tc.replace(" ", "").startswith("ManyToManyField("):
  109. myfor = tc.replace(" ", "")[16:-1]
  110. if actclas == myfor:
  111. # In case of a recursive type, we use 'self'
  112. tc = tc.replace(myfor, "'self'")
  113. elif clases[actclas][0].count(myfor) == 0:
  114. # Adding related class
  115. if myfor not in dependclasses:
  116. # In case we are using Auth classes or external via protected dia visibility
  117. clases[actclas][0].append(myfor)
  118. tc = "models." + tc
  119. if len(val) > 0:
  120. tc = tc.replace(")", "," + val + ")")
  121. elif tc.find("Field") != -1:
  122. if tc.count("()") > 0 and len(val) > 0:
  123. tc = "models.%s" % tc.replace(")", "," + val + ")")
  124. else:
  125. tc = "models.%s(%s)" % (tc, val)
  126. elif tc.replace(" ", "").startswith("ForeignKey("):
  127. myfor = tc.replace(" ", "")[11:-1]
  128. if actclas == myfor:
  129. # In case of a recursive type, we use 'self'
  130. tc = tc.replace(myfor, "'self'")
  131. elif clases[actclas][0].count(myfor) == 0:
  132. # Adding foreign classes
  133. if myfor not in dependclasses:
  134. # In case we are using Auth classes
  135. clases[actclas][0].append(myfor)
  136. tc = "models." + tc
  137. if len(val) > 0:
  138. tc = tc.replace(")", "," + val + ")")
  139. elif varch is None:
  140. tc = "models." + tsd[tc.strip().lower()] + "(" + val + ")"
  141. else:
  142. tc = "models.CharField(max_length=" + varch.group(1) + ")"
  143. if len(val) > 0:
  144. tc = tc.replace(")", ", " + val + " )")
  145. if not (nc == "id" and tc == "AutoField()"):
  146. clases[actclas][2] += " %s = %s\n" % (nc, tc)
  147. elif i.getAttribute("type") == "UML - Generalization":
  148. mycons = ['A', 'A']
  149. a = i.getElementsByTagName("dia:connection")
  150. for j in a:
  151. if len(j.getAttribute("to")):
  152. mycons[int(j.getAttribute("handle"))] = j.getAttribute("to")
  153. print(mycons)
  154. if 'A' not in mycons:
  155. herit.append(mycons)
  156. elif i.getAttribute("type") == "UML - SmallPackage":
  157. a = i.getElementsByTagName("dia:string")
  158. for j in a:
  159. if len(j.childNodes[0].data[1:-1]):
  160. imports += six.u("from %s.models import *" % j.childNodes[0].data[1:-1])
  161. addparentstofks(herit, clases)
  162. # Ordering the appearance of classes
  163. # First we make a list of the classes each classs is related to.
  164. ordered = []
  165. for j, k in six.iteritems(clases):
  166. k[2] += "\n def __str__(self):\n return u\"\"\n"
  167. for fk in k[0]:
  168. if fk not in dependclasses:
  169. clases[fk][3] += 1
  170. ordered.append([j] + k)
  171. i = 0
  172. while i < len(ordered):
  173. mark = i
  174. j = i + 1
  175. while j < len(ordered):
  176. if ordered[i][0] in ordered[j][1]:
  177. mark = j
  178. j += 1
  179. if mark == i:
  180. i += 1
  181. else:
  182. # swap %s in %s" % ( ordered[i] , ordered[mark]) to make ordered[i] to be at the end
  183. if ordered[i][0] in ordered[mark][1] and ordered[mark][0] in ordered[i][1]:
  184. # Resolving simplistic circular ForeignKeys
  185. print("Not able to resolve circular ForeignKeys between %s and %s" % (ordered[i][1], ordered[mark][0]))
  186. break
  187. a = ordered[i]
  188. ordered[i] = ordered[mark]
  189. ordered[mark] = a
  190. if i == len(ordered) - 1:
  191. break
  192. ordered.reverse()
  193. if imports:
  194. models_txt = str(imports)
  195. for i in ordered:
  196. models_txt += '%s\n' % str(i[3])
  197. return models_txt
  198. if __name__ == '__main__':
  199. if len(sys.argv) == 2:
  200. dia2django(sys.argv[1])
  201. else:
  202. print(" Use:\n \n " + sys.argv[0] + " diagram.dia\n\n")