batch_task.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import asyncio
  2. import threading
  3. from app.admin.crud.crud_intent_org import intent_org_dao
  4. from app.admin.schema.intent_org import CurrentIntentOrgIns
  5. from batch_task.update_llm_intent import update_llm_intent, process_llm_intent
  6. from batch_task.update_mismatch_record import update_mismatch_record, process_mismatch_record
  7. from common.log import log
  8. from core.conf import settings
  9. from database.db_mysql import async_db_session
  10. from utils.serializers import select_as_dict
  11. batch_task_event = threading.Event()
  12. async def periodically_execute(interval, func, *args, **kwargs):
  13. while True:
  14. await func(*args, **kwargs)
  15. await asyncio.sleep(interval)
  16. async def execute_task():
  17. org_map = {}
  18. async with async_db_session.begin() as db:
  19. orgs = await intent_org_dao.get_all(db)
  20. if orgs:
  21. for org in orgs:
  22. if org.status == 1:
  23. org_map[org.id] = CurrentIntentOrgIns(**select_as_dict(org))
  24. batch_task_event.set()
  25. workers = settings.BATCH_CONCURRENT
  26. log.info(f"Starting task with {workers} workers.")
  27. # average_workers = int(round(workers / 2))
  28. while batch_task_event.is_set():
  29. len_llm_intent = await update_llm_intent(org_map, workers)
  30. len_mismatch_record = await update_mismatch_record(org_map, workers)
  31. if len_llm_intent or len_mismatch_record:
  32. log.info("Task executed")
  33. else:
  34. log.info("Task executed. Sleeping for 10 seconds...")
  35. await asyncio.sleep(10)
  36. def start_batch_task():
  37. global batch_task_event
  38. if not batch_task_event.is_set():
  39. log.info("Starting task.")
  40. batch_task_event.set()
  41. asyncio.run(execute_task())
  42. # asyncio.create_task(execute_task())
  43. def stop_batch_task():
  44. global batch_task_event
  45. # 清除事件
  46. batch_task_event.clear()
  47. log.info("Stop task.")