drawer.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
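# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: Render the geographic distribution of samples given as
    province/city/district (省/市/区) columns into HTML maps: a folium
    heatmap, a pyecharts heatmap, and a per-category pyecharts scatter map.
"""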
  6. import pandas as pd
  7. from . import latlng
  8. def _base_input_check(locations):
  9. from .exceptions import InputTypeNotSuportException
  10. if not isinstance(locations, pd.DataFrame):
  11. raise InputTypeNotSuportException(InputTypeNotSuportException.input_type)
  12. if "省" not in locations.columns or "市" not in locations.columns \
  13. or "区" not in locations.columns:
  14. raise InputTypeNotSuportException(InputTypeNotSuportException.input_type)
  15. _lnglat = dict([(item[0], tuple(reversed(item[1]))) for item in latlng.items()])
  16. def _geo_update(geo):
  17. geo._coordinates = _lnglat
  18. def draw_locations(locations, file_path):
  19. """
  20. 基于folium生成地域分布的热力图的html文件.
  21. :param locations: 样本的省市区, pandas的dataframe类型.
  22. :param file_path: 生成的html文件的路径.
  23. """
  24. _base_input_check(locations)
  25. import folium
  26. from folium.plugins import HeatMap
  27. # 注意判断key是否存在
  28. heatData = []
  29. for map_key in zip(locations["省"], locations["市"], locations["区"]):
  30. if latlng.get(map_key):
  31. lat_lon = latlng.get(map_key)
  32. if lat_lon[0] and lat_lon[1]:
  33. heatData.append([float(lat_lon[0]), float(lat_lon[1]), 1])
  34. else:
  35. print('no latlng:', map_key, lat_lon)
  36. # 绘制Map,开始缩放程度是5倍
  37. map_osm = folium.Map(location=[35, 110], zoom_start=5)
  38. # 将热力图添加到前面建立的map里
  39. HeatMap(heatData).add_to(map_osm)
  40. # 保存为html文件
  41. map_osm.save(file_path)
  42. def echarts_draw(locations, file_path, title="地域分布图", subtitle="location distribute"):
  43. """
  44. 生成地域分布的echarts热力图的html文件.
  45. :param locations: 样本的省市区, pandas的dataframe类型.
  46. :param file_path: 生成的html文件路径.
  47. :param title: 图表的标题
  48. :param subtitle: 图表的子标题
  49. """
  50. from pyecharts import Geo
  51. _base_input_check(locations)
  52. count_map = {}
  53. for map_key in zip(locations["省"], locations["市"], locations["区"]):
  54. if latlng.get(map_key):
  55. count_map[map_key] = count_map.get(map_key, 0) + 1
  56. geo = Geo(title, subtitle, title_color="#fff",
  57. title_pos="center", width=1200,
  58. height=600, background_color='#404a59')
  59. _geo_update(geo)
  60. attr, value = geo.cast(count_map)
  61. geo.add("", attr, value, type="heatmap", is_visualmap=True,
  62. visual_text_color='#fff',
  63. is_piecewise=True, visual_split_number=10)
  64. geo.render(file_path)
  65. def echarts_cate_draw(locations, labels, file_path, title="地域分布图", subtitle="location distribute",
  66. point_size=7):
  67. """
  68. 依据分类生成地域分布的echarts散点图的html文件.
  69. :param locations: 样本的省市区, pandas的dataframe类型.
  70. :param labels: 长度必须和locations相等, 代表每个样本所属的分类.
  71. :param file_path: 生成的html文件路径.
  72. :param title: 图表的标题
  73. :param subtitle: 图表的子标题
  74. :param point_size: 每个散点的大小
  75. """
  76. _base_input_check(locations)
  77. if len(locations) != len(labels):
  78. from .exceptions import CPCAException
  79. raise CPCAException("locations的长度与labels长度必须相等")
  80. from pyecharts import Geo
  81. geo = Geo(title, subtitle, title_color="#000000",
  82. title_pos="center", width=1200,
  83. height=600, background_color='#fff')
  84. _geo_update(geo)
  85. uniques = set(list(labels))
  86. def _data_add(_geo, _cate_keys, _category):
  87. real_keys = []
  88. for cate_key in _cate_keys:
  89. if latlng.get(cate_key):
  90. real_keys.append(cate_key)
  91. attr = real_keys
  92. value = [1] * len(real_keys)
  93. geo.add(_category, attr, value, symbol_size=point_size,
  94. legend_pos="left", legend_top="bottom",
  95. geo_normal_color="#fff",
  96. geo_emphasis_color=" #f0f0f5")
  97. for category in uniques:
  98. cate_locations = locations[labels == category]
  99. _data_add(geo, zip(cate_locations["省"], cate_locations["市"],
  100. cate_locations["区"]), category)
  101. geo.render(file_path)